sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/weight_sync/utils.py
CHANGED
@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
|
|
6
6
|
from torch.distributed.tensor import DTensor
|
7
7
|
|
8
8
|
from sglang.srt.entrypoints.engine import Engine
|
9
|
-
from sglang.srt.managers.
|
9
|
+
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
|
10
10
|
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
11
11
|
from sglang.srt.utils import MultiprocessingSerializer
|
12
12
|
|
@@ -33,7 +33,7 @@ async def update_weights(
|
|
33
33
|
"""
|
34
34
|
infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
|
35
35
|
infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
|
36
|
-
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
36
|
+
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
|
37
37
|
|
38
38
|
monkey_patch_torch_reductions()
|
39
39
|
|
sglang/test/attention/test_trtllm_mla_backend.py
CHANGED
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
|
|
41
41
|
"v_head_dim": 512,
|
42
42
|
"num_kv_heads": 1,
|
43
43
|
"layer_id": 0,
|
44
|
+
"tp_q_head_num": 128,
|
45
|
+
"tp_k_head_num": 128,
|
46
|
+
"prefill_head_dim": 192,
|
47
|
+
"prefill_v_head_dim": 128,
|
44
48
|
}
|
45
49
|
|
46
50
|
ROPE_BASE = 10000
|
@@ -92,7 +96,7 @@ TEST_CASES = {
|
|
92
96
|
"description": "Medium-scale batch",
|
93
97
|
},
|
94
98
|
],
|
95
|
-
"
|
99
|
+
"output_match": [
|
96
100
|
{
|
97
101
|
"name": "single_fp16",
|
98
102
|
"batch_size": 1,
|
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
322
326
|
config.update(test_case)
|
323
327
|
return config
|
324
328
|
|
325
|
-
def _create_model_components(self, config):
|
329
|
+
def _create_model_components(self, config, is_prefill=False):
|
326
330
|
"""Create model runners, backends, and layer for testing."""
|
327
331
|
# Create model runners
|
328
332
|
model_runner_trtllm = MockModelRunner(config)
|
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
332
336
|
trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
|
333
337
|
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
|
334
338
|
|
339
|
+
head_dim = (
|
340
|
+
config["kv_lora_rank"] + config["qk_rope_head_dim"]
|
341
|
+
if not is_prefill
|
342
|
+
else config["prefill_head_dim"]
|
343
|
+
)
|
344
|
+
v_head_dim = (
|
345
|
+
config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
|
346
|
+
)
|
347
|
+
|
335
348
|
# Create RadixAttention layer
|
336
349
|
layer = RadixAttention(
|
337
350
|
num_heads=config["num_attention_heads"],
|
338
|
-
head_dim=
|
351
|
+
head_dim=head_dim,
|
339
352
|
scaling=model_runner_trtllm.model_config.scaling,
|
340
353
|
num_kv_heads=config["num_kv_heads"],
|
341
354
|
layer_id=config["layer_id"],
|
342
|
-
v_head_dim=
|
355
|
+
v_head_dim=v_head_dim,
|
343
356
|
prefix="attn_mqa",
|
344
357
|
)
|
345
358
|
|
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
524
537
|
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
|
525
538
|
print(f"\nRunning decode output matching tests...")
|
526
539
|
|
527
|
-
for test_case in TEST_CASES["
|
540
|
+
for test_case in TEST_CASES["output_match"]:
|
528
541
|
with self.subTest(test_case=test_case["name"]):
|
529
542
|
print(f" Testing {test_case['name']}: {test_case['description']}")
|
530
543
|
|
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
1099
1112
|
self.assertIsNotNone(metadata_3.block_kv_indices)
|
1100
1113
|
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
|
1101
1114
|
|
1115
|
+
def test_prefill_output_match_self_attention(self):
|
1116
|
+
"""Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
|
1117
|
+
print(f"\nRunning prefill output tests...")
|
1118
|
+
|
1119
|
+
for test_case in TEST_CASES["output_match"][:2]: # Just a subset for speed
|
1120
|
+
with self.subTest(test_case=test_case["name"]):
|
1121
|
+
print(
|
1122
|
+
f"Prefill Testing {test_case['name']}: {test_case['description']}"
|
1123
|
+
)
|
1124
|
+
|
1125
|
+
config = self._merge_config(test_case)
|
1126
|
+
batch_size = config["batch_size"]
|
1127
|
+
max_seq_len = config["max_seq_len"]
|
1128
|
+
|
1129
|
+
# Create components
|
1130
|
+
(
|
1131
|
+
model_runner_trtllm,
|
1132
|
+
model_runner_reference,
|
1133
|
+
trtllm_backend,
|
1134
|
+
reference_backend,
|
1135
|
+
layer,
|
1136
|
+
) = self._create_model_components(config, is_prefill=True)
|
1137
|
+
|
1138
|
+
# Prefill uses full sequences
|
1139
|
+
seq_lens = torch.full(
|
1140
|
+
(batch_size,), max_seq_len, device=config["device"]
|
1141
|
+
)
|
1142
|
+
|
1143
|
+
def _create_forward_batch_prefill(
|
1144
|
+
batch_size,
|
1145
|
+
seq_lens,
|
1146
|
+
extend_prefix_lens,
|
1147
|
+
backend,
|
1148
|
+
model_runner,
|
1149
|
+
config,
|
1150
|
+
):
|
1151
|
+
"""Create a forward batch for the given backend."""
|
1152
|
+
|
1153
|
+
fb = ForwardBatch(
|
1154
|
+
batch_size=batch_size,
|
1155
|
+
input_ids=torch.randint(
|
1156
|
+
0, 100, (batch_size, 1), device=config["device"]
|
1157
|
+
),
|
1158
|
+
out_cache_loc=torch.arange(batch_size, device=config["device"]),
|
1159
|
+
seq_lens_sum=int(seq_lens.sum().item()),
|
1160
|
+
extend_prefix_lens=extend_prefix_lens,
|
1161
|
+
extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
|
1162
|
+
extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
|
1163
|
+
.cpu()
|
1164
|
+
.int()
|
1165
|
+
.tolist(),
|
1166
|
+
forward_mode=ForwardMode.EXTEND,
|
1167
|
+
req_pool_indices=torch.arange(
|
1168
|
+
batch_size, device=config["device"]
|
1169
|
+
),
|
1170
|
+
seq_lens=seq_lens,
|
1171
|
+
seq_lens_cpu=seq_lens.cpu(),
|
1172
|
+
attn_attend_prefix_cache=False,
|
1173
|
+
mha_return_lse=False,
|
1174
|
+
attn_backend=backend,
|
1175
|
+
)
|
1176
|
+
fb.req_to_token_pool = model_runner.req_to_token_pool
|
1177
|
+
fb.token_to_kv_pool = model_runner.token_to_kv_pool
|
1178
|
+
|
1179
|
+
# Add position information for RoPE
|
1180
|
+
fb.positions = torch.arange(batch_size, device=config["device"])
|
1181
|
+
|
1182
|
+
return fb
|
1183
|
+
|
1184
|
+
# Create forward batches
|
1185
|
+
fb_trtllm = _create_forward_batch_prefill(
|
1186
|
+
batch_size,
|
1187
|
+
seq_lens.clone(),
|
1188
|
+
torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
|
1189
|
+
trtllm_backend,
|
1190
|
+
model_runner_trtllm,
|
1191
|
+
config,
|
1192
|
+
)
|
1193
|
+
fb_reference = _create_forward_batch_prefill(
|
1194
|
+
batch_size,
|
1195
|
+
seq_lens.clone(),
|
1196
|
+
torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
|
1197
|
+
reference_backend,
|
1198
|
+
model_runner_reference,
|
1199
|
+
config,
|
1200
|
+
)
|
1201
|
+
|
1202
|
+
# Initialize metadata for both backends
|
1203
|
+
trtllm_backend.init_forward_metadata(fb_trtllm)
|
1204
|
+
reference_backend.init_forward_metadata(fb_reference)
|
1205
|
+
|
1206
|
+
# Create Q, K, V tensors for prefill
|
1207
|
+
torch.manual_seed(config["seed_qkv"])
|
1208
|
+
|
1209
|
+
def _create_qkv_tensors_prefill(
|
1210
|
+
batch_size, seq_len, config, dtype_override=None
|
1211
|
+
):
|
1212
|
+
"""Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
|
1213
|
+
device = config["device"]
|
1214
|
+
dtype = dtype_override or config["dtype"]
|
1215
|
+
|
1216
|
+
total_tokens = batch_size * seq_len
|
1217
|
+
|
1218
|
+
tp_q_head_num = config["tp_q_head_num"]
|
1219
|
+
tp_k_head_num = config["tp_k_head_num"]
|
1220
|
+
head_dim = config["prefill_head_dim"]
|
1221
|
+
v_head_dim = config["prefill_v_head_dim"]
|
1222
|
+
|
1223
|
+
q = torch.randn(
|
1224
|
+
(total_tokens, tp_q_head_num * head_dim),
|
1225
|
+
dtype=dtype,
|
1226
|
+
device=device,
|
1227
|
+
)
|
1228
|
+
k = torch.randn(
|
1229
|
+
(total_tokens, tp_k_head_num * head_dim),
|
1230
|
+
dtype=dtype,
|
1231
|
+
device=device,
|
1232
|
+
)
|
1233
|
+
v = torch.randn(
|
1234
|
+
(total_tokens, tp_k_head_num * v_head_dim),
|
1235
|
+
dtype=dtype,
|
1236
|
+
device=device,
|
1237
|
+
)
|
1238
|
+
|
1239
|
+
# Reshape as requested
|
1240
|
+
q = q.view(-1, tp_q_head_num, head_dim)
|
1241
|
+
k = k.view(-1, tp_k_head_num, head_dim)
|
1242
|
+
v = v.view(-1, tp_k_head_num, v_head_dim)
|
1243
|
+
|
1244
|
+
return q, k, v
|
1245
|
+
|
1246
|
+
q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
|
1247
|
+
# Run prefill on both backends
|
1248
|
+
out_trtllm = trtllm_backend.forward_extend(
|
1249
|
+
q, k, v, layer, fb_trtllm, False
|
1250
|
+
).view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
1251
|
+
out_reference = reference_backend.forward_extend(
|
1252
|
+
q, k, v, layer, fb_reference, False
|
1253
|
+
)
|
1254
|
+
|
1255
|
+
tolerance = config.get("tolerance", 1e-2)
|
1256
|
+
comparison_passed = compare_outputs(
|
1257
|
+
out_trtllm, out_reference, tolerance=tolerance
|
1258
|
+
)
|
1259
|
+
self.assertTrue(
|
1260
|
+
comparison_passed,
|
1261
|
+
f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
|
1262
|
+
f"Config: {test_case['name']}, "
|
1263
|
+
f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
|
1264
|
+
)
|
1265
|
+
|
1102
1266
|
|
1103
1267
|
if __name__ == "__main__":
|
1104
1268
|
unittest.main()
|
sglang/test/get_logits_ut.py
ADDED
@@ -0,0 +1,57 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
|
4
|
+
|
5
|
+
class DummyModel(nn.Module):
    """Toy module comparing two equivalent ways of scaling head-gate logits.

    Both ``_get_logits_head_gate_*`` methods project ``x`` through
    ``weights_proj`` and multiply the result by ``n_heads**-0.5``,
    ``q_scale`` and ``softmax_scale``. They differ only in the order in
    which those factors are combined.
    """

    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
        super().__init__()
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale
        # The projection width is fixed at 1024 output features.
        self.weights_proj = nn.Linear(d_in, 1024)

    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Reference path: every scalar factor is applied to the full tensor.

        The projection is scaled by ``n_heads**-0.5`` first, then broadcast
        against ``q_scale`` and finally multiplied by ``softmax_scale``.
        """
        head_norm = self.n_heads**-0.5
        projected = self.weights_proj(x) * head_norm
        # For the (B, 1) q_scale used by main() this becomes (B, 1, 1).
        row_scale = q_scale.unsqueeze(1)
        return projected.unsqueeze(-1) * row_scale * self.softmax_scale

    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Optimized path: fold all scale factors together before touching the projection.

        The combined factor is computed on the small ``q_scale`` tensor, so
        the projected tensor is multiplied only once.
        """
        projected = self.weights_proj(x)
        # Same left-to-right factor order as the original expression.
        combined_scale = self.n_heads**-0.5 * q_scale.unsqueeze(1) * self.softmax_scale
        return projected.unsqueeze(-1) * combined_scale
|
25
|
+
|
26
|
+
|
27
|
+
def main():
    """Benchmark the original and optimized head-gate scaling and check they agree.

    Builds a seeded ``DummyModel`` and random inputs, times 1000 calls of
    each variant, prints both timings and the maximum absolute difference,
    and asserts that the two outputs are numerically close.

    Raises:
        AssertionError: if the two variants disagree beyond
            ``torch.allclose`` default tolerances.
    """
    import time

    torch.manual_seed(0)
    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
    x = torch.randn(128, 2048)  # batch=128, d_in=2048
    q_scale = torch.randn(128, 1)

    num_iters = 1000

    # Gradients are never used here; disabling autograd keeps graph
    # bookkeeping out of the measured time.
    with torch.no_grad():
        # Warm-up so one-time start-up cost is not charged to whichever
        # variant happens to be timed first.
        model._get_logits_head_gate_orig(x, q_scale)
        model._get_logits_head_gate_opt(x, q_scale)

        # perf_counter is monotonic and high resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start = time.perf_counter()
        for _ in range(num_iters):
            out_orig = model._get_logits_head_gate_orig(x, q_scale)
        print("Original version time:", time.perf_counter() - start)

        start = time.perf_counter()
        for _ in range(num_iters):
            out_opt = model._get_logits_head_gate_opt(x, q_scale)
        print("Optimized version time:", time.perf_counter() - start)

    print("Difference:", (out_orig - out_opt).abs().max().item())
    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
|
47
|
+
|
48
|
+
|
49
|
+
if __name__ == "__main__":
|
50
|
+
main()
|
51
|
+
|
52
|
+
|
53
|
+
"""
|
54
|
+
Original version time: 0.49235057830810547
|
55
|
+
Optimized version time: 0.4087331295013428
|
56
|
+
Difference: 1.4901161193847656e-08
|
57
|
+
"""
|
@@ -0,0 +1 @@
|
|
1
|
+
"""LongBench-v2 auxiliary utilities and validation scripts."""
|
@@ -0,0 +1,238 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for LongBench-v2 evaluation utility.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import json
|
6
|
+
import os
|
7
|
+
import tempfile
|
8
|
+
|
9
|
+
from sglang.test.simple_eval_longbench_v2 import (
|
10
|
+
LongBenchV2Eval,
|
11
|
+
extract_longbench_v2_answer,
|
12
|
+
format_longbench_v2_question,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
def test_format_longbench_v2_question():
    """Check that a formatted prompt carries the context, question, options and answer cue."""
    row = dict(
        context="This is a sample context about environmental issues.",
        question="What is the main theme?",
        A="Technology",
        B="Environment",
        C="Economics",
        D="Politics",
        answer="B",
    )

    prompt = format_longbench_v2_question(row)

    # Each fragment of the official template must appear verbatim in the prompt.
    required_fragments = (
        "This is a sample context about environmental issues.",
        "What is the correct answer to this question: What is the main theme?",
        "(A) Technology",
        "(B) Environment",
        "(C) Economics",
        "(D) Politics",
        "The correct answer is",
    )
    for fragment in required_fragments:
        assert fragment in prompt
    print("✓ Question formatting works correctly")
|
42
|
+
|
43
|
+
|
44
|
+
def test_extract_longbench_v2_answer():
    """Check answer-letter extraction across the supported response phrasings."""
    # (model response, expected letter) pairs, in this order:
    # parenthesized official form, bare-letter form, asterisk-wrapped form,
    # and the generic "the answer is X" fallback.
    letter_cases = [
        ("After analyzing the context, The correct answer is (B).", "B"),
        ("Based on the evidence, The correct answer is C.", "C"),
        ("*The correct answer is (D)*", "D"),
        ("I think the answer is A.", "A"),
    ]
    for reply, letter in letter_cases:
        assert extract_longbench_v2_answer(reply) == letter

    # A response with no recognizable answer must yield None.
    assert extract_longbench_v2_answer("I'm not sure about this.") is None

    print("✓ Answer extraction works correctly")
|
68
|
+
|
69
|
+
|
70
|
+
def test_longbench_v2_eval_initialization():
    """Check that LongBenchV2Eval loads a JSON file and honours ``num_examples``.

    The fixture mixes two row layouts: one with ``choice_``-prefixed option
    keys and one with bare ``A``-``D`` keys. The assertions accept either row
    as the single loaded example and expect it to expose its domain under
    ``"category"`` and its first option under ``"A"``.
    """
    prefixed_row = {
        "_id": "test_001",
        "domain": "single_document_qa",
        "question": "What is X?",
        "choice_A": "Option A1",
        "choice_B": "Option B1",
        "choice_C": "Option C1",
        "choice_D": "Option D1",
        "answer": "A",
        "context": "Context 1",
    }
    bare_letter_row = {
        "_id": "test_002",
        "domain": "multi_document_qa",
        "question": "What is Y?",
        "A": "Option A2",
        "B": "Option B2",
        "C": "Option C2",
        "D": "Option D2",
        "answer": "B",
        "context": "Context 2",
    }

    # delete=False: the file must outlive the ``with`` so the evaluator can
    # reopen it by path; it is removed in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
        json.dump([prefixed_row, bare_letter_row], handle)
    dataset_path = handle.name

    try:
        evaluator = LongBenchV2Eval(data_source=dataset_path, num_examples=1)
        loaded = evaluator.examples
        assert len(loaded) == 1

        sample = loaded[0]
        assert sample.get("category") in {"single_document_qa", "multi_document_qa"}
        assert sample.get("A") in {"Option A1", "Option A2"}
        print("✓ Evaluation class initialization works correctly")
    finally:
        os.unlink(dataset_path)
|
117
|
+
|
118
|
+
|
119
|
+
def test_category_filtering():
    """Check that passing ``categories`` leaves only examples from those domains."""
    single_doc_row = {
        "_id": "test_001",
        "domain": "single_document_qa",
        "question": "What is X?",
        "choice_A": "Option A1",
        "choice_B": "Option B1",
        "choice_C": "Option C1",
        "choice_D": "Option D1",
        "answer": "A",
        "context": "Context 1",
    }
    multi_doc_row = {
        "_id": "test_002",
        "domain": "multi_document_qa",
        "question": "What is Y?",
        "choice_A": "Option A2",
        "choice_B": "Option B2",
        "choice_C": "Option C2",
        "choice_D": "Option D2",
        "answer": "B",
        "context": "Context 2",
    }

    # delete=False: keep the file on disk after the ``with`` so it can be
    # reopened by path; the ``finally`` below removes it.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
        json.dump([single_doc_row, multi_doc_row], handle)
    dataset_path = handle.name

    try:
        evaluator = LongBenchV2Eval(
            data_source=dataset_path,
            categories=["multi_document_qa"],
        )
        kept = evaluator.examples
        # Only the multi-document row should survive the filter.
        assert len(kept) == 1
        assert kept[0]["category"] == "multi_document_qa"
        print("✓ Category filtering works correctly")
    finally:
        os.unlink(dataset_path)
|
161
|
+
|
162
|
+
|
163
|
+
def test_difficulty_metrics():
    """Check that per-difficulty metrics appear in the evaluation result.

    One easy and one hard example are evaluated with a stub sampler that
    always answers correctly, so both ``difficulty_easy`` and
    ``difficulty_hard`` are expected to equal 1.0.
    """

    class FixedSampler:  # noqa: D401 - simple helper
        """Mock sampler returning the correct answer based on question text."""

        def _pack_message(self, content: str, role: str):
            return {"content": content, "role": role}

        def __call__(self, messages):
            first_prompt = messages[0]["content"]
            # The easy fixture's answer is A; anything else is the hard one (B).
            letter = "A" if "Easy question" in first_prompt else "B"
            return "The correct answer is (" + letter + ")"

    easy_row = {
        "_id": "easy_001",
        "domain": "single_document_qa",
        "difficulty": "easy",
        "question": "Easy question?",
        "choice_A": "Correct",
        "choice_B": "Wrong",
        "choice_C": "Wrong",
        "choice_D": "Wrong",
        "answer": "A",
        "context": "Easy context",
    }
    hard_row = {
        "_id": "hard_001",
        "domain": "single_document_qa",
        "difficulty": "hard",
        "question": "Hard question?",
        "choice_A": "Wrong",
        "choice_B": "Correct",
        "choice_C": "Wrong",
        "choice_D": "Wrong",
        "answer": "B",
        "context": "Hard context",
    }

    # delete=False: the evaluator reopens the file by path after the ``with``
    # has closed it; cleanup happens in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
        json.dump([easy_row, hard_row], handle)
    dataset_path = handle.name

    try:
        evaluator = LongBenchV2Eval(data_source=dataset_path, num_threads=1)
        outcome = evaluator(FixedSampler())

        metrics = outcome.metrics
        assert metrics.get("difficulty_easy") == 1.0
        assert metrics.get("difficulty_hard") == 1.0
        print("✓ Difficulty metrics recorded correctly")
    finally:
        os.unlink(dataset_path)
|
218
|
+
|
219
|
+
|
220
|
+
def main():
    """Run every LongBench-v2 check in order, then print a summary banner.

    Any failing check raises its ``AssertionError`` straight through, so the
    banner is printed only when all checks pass.
    """
    print("Testing simplified LongBench-v2 evaluation utility...\n")

    checks = (
        test_format_longbench_v2_question,
        test_extract_longbench_v2_answer,
        test_longbench_v2_eval_initialization,
        test_category_filtering,
        test_difficulty_metrics,
    )
    for check in checks:
        check()

    rule = "=" * 50
    print("\n" + rule)
    print("✅ ALL TESTS PASSED!")
    print("The simplified implementation follows SGLang patterns")
    print("while maintaining LongBench-v2 compatibility.")
    print(rule)
|
235
|
+
|
236
|
+
|
237
|
+
if __name__ == "__main__":
|
238
|
+
main()
|