sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/test/run_eval.py
CHANGED
@@ -10,11 +10,46 @@ import time
 
 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
+    Eval,
     make_report,
     set_ulimit,
 )
 
 
+def get_thinking_kwargs(args):
+    thinking_mode = getattr(args, "thinking_mode", None)
+    if thinking_mode in THINKING_MODE_CHOICES:
+        if thinking_mode == "deepseek-v3":
+            thinking_param = "thinking"
+        else:
+            thinking_param = "enable_thinking"
+        return {
+            "chat_template_kwargs": {thinking_param: True},
+        }
+    return {}
+
+
+def run_eval_once(args, base_url: str, eval_obj: Eval) -> dict:
+    # Get thinking kwargs based on user's choice
+    thinking_kwargs = get_thinking_kwargs(args)
+
+    sampler = ChatCompletionSampler(
+        model=args.model,
+        max_tokens=getattr(args, "max_tokens", 2048),
+        base_url=base_url,
+        temperature=getattr(args, "temperature", 0.0),
+        reasoning_effort=getattr(args, "reasoning_effort", None),
+        extra_body=thinking_kwargs,
+    )
+
+    # Run eval
+    tic = time.perf_counter()
+    result = eval_obj(sampler)
+    latency = time.perf_counter() - tic
+
+    return result, latency, sampler
+
+
 def run_eval(args):
     set_ulimit()
 
@@ -60,21 +95,55 @@ def run_eval(args):
         from sglang.test.simple_eval_humaneval import HumanEval
 
         eval_obj = HumanEval(args.num_examples, args.num_threads)
+    elif args.eval_name == "longbench_v2":
+        from sglang.test.simple_eval_longbench_v2 import LongBenchV2Eval
+
+        # Default to HuggingFace dataset, can be overridden with --dataset-path
+        data_source = args.dataset_path
+        categories = args.categories.split(",") if args.categories else None
+
+        eval_obj = LongBenchV2Eval(
+            data_source=data_source,
+            num_examples=args.num_examples,
+            num_threads=args.num_threads,
+            categories=categories,
+            max_context_length=getattr(args, "max_context_length", None),
+            min_context_length=getattr(args, "min_context_length", None),
+        )
+    elif args.eval_name == "mmmu":
+        # VLM MMMU evaluation with fixed 100 examples by default
+        from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval
+
+        eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
     else:
         raise ValueError(f"Invalid eval name: {args.eval_name}")
 
-    sampler = ChatCompletionSampler(
-        model=args.model,
-        max_tokens=getattr(args, "max_tokens", 2048),
-        base_url=base_url,
-        temperature=getattr(args, "temperature", 0.0),
-        reasoning_effort=getattr(args, "reasoning_effort", None),
-    )
+    if getattr(args, "repeat", 1) == 1:
+        result, latency, sampler = run_eval_once(args, base_url, eval_obj)
+    else:
+        from concurrent.futures import ThreadPoolExecutor
 
-    # Run eval
-    tic = time.perf_counter()
-    result = eval_obj(sampler)
-    latency = time.perf_counter() - tic
+        executor = ThreadPoolExecutor(max_workers=args.repeat)
+
+        futures = [
+            executor.submit(run_eval_once, args, base_url, eval_obj)
+            for _ in range(args.repeat)
+        ]
+
+        scores_repeat = []
+
+        for f in futures:
+            result, latency, sampler = f.result()
+            scores_repeat.append(result.score)
+
+        mean_score = sum(scores_repeat) / len(scores_repeat)
+        scores_repeat = [f"{s:.3f}" for s in scores_repeat]
+        print("=" * 20)
+        print(f"Repeat: {args.repeat}, mean: {mean_score:.3f}")
+        print(f"Scores: {scores_repeat}")
+        print("=" * 20)
+
+        executor.shutdown()
 
     # Dump reports
     metrics = result.metrics | {"score": result.score}
@@ -94,9 +163,13 @@ def run_eval(args):
     print(f"Total latency: {latency:.3f} s")
     print(f"Score: {metrics['score']:.3f}")
 
+    if getattr(args, "return_latency", False):
+        return metrics, latency
     return metrics
 
 
+THINKING_MODE_CHOICES = ["deepseek-r1", "deepseek-v3", "qwen3"]
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -118,12 +191,47 @@ if __name__ == "__main__":
         type=str,
         help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
     )
+    parser.add_argument(
+        "--repeat", type=int, default=1, help="repeat the evaluation n times"
+    )
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
     parser.add_argument("--max-tokens", type=int, default=2048)
     parser.add_argument("--temperature", type=float, default=0.0)
     parser.add_argument("--reasoning-effort", type=str)
+    parser.add_argument(
+        "--thinking-mode",
+        default=None,
+        type=str,
+        choices=THINKING_MODE_CHOICES,
+        help="Enable thinking mode in Deepseek R1, V3.1/3.2, or Qwen3",
+    )
+
+    # LongBench-v2 specific arguments
+    parser.add_argument(
+        "--dataset-path",
+        type=str,
+        default="THUDM/LongBench-v2",
+        help="Path to dataset file or HuggingFace dataset name for LongBench-v2",
+    )
+    parser.add_argument(
+        "--categories",
+        type=str,
+        default=None,
+        help="Comma-separated list of categories to evaluate for LongBench-v2",
+    )
+    parser.add_argument(
+        "--max-context-length",
+        type=int,
+        help="Maximum context length in characters for LongBench-v2",
+    )
+    parser.add_argument(
+        "--min-context-length",
+        type=int,
+        help="Minimum context length in characters for LongBench-v2",
+    )
+
     args = parser.parse_args()
 
     run_eval(args)
sglang/test/runners.py
CHANGED
@@ -30,8 +30,8 @@ from transformers import (
 )
 
 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import load_image
+from sglang.srt.utils.hf_transformers_utils import get_tokenizer
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
 
 DEFAULT_PROMPTS = [
@@ -505,6 +505,7 @@ class SRTRunner:
         mem_fraction_static: float = 0.65,
         trust_remote_code: bool = False,
         speculative_draft_model_path: Optional[str] = None,
+        speculative_draft_model_revision: Optional[str] = None,
         speculative_algorithm: Optional[str] = None,
         speculative_num_steps: Optional[int] = None,
         speculative_eagle_topk: Optional[int] = None,
@@ -526,6 +527,9 @@ class SRTRunner:
         spec_kwargs = {}
         if speculative_draft_model_path:
             spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
+            spec_kwargs["speculative_draft_model_revision"] = (
+                speculative_draft_model_revision
+            )
             spec_kwargs["speculative_algorithm"] = speculative_algorithm
             spec_kwargs["speculative_num_steps"] = speculative_num_steps
             spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
sglang/test/simple_eval_common.py
CHANGED
@@ -93,6 +93,7 @@ class ChatCompletionSampler(SamplerBase):
         temperature: float = 0.0,
         reasoning_effort: Optional[str] = None,
         max_tokens: int = 2048,
+        extra_body: Optional[Dict[str, Any]] = None,
     ):
         self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
 
@@ -104,9 +105,10 @@ class ChatCompletionSampler(SamplerBase):
         self.temperature = temperature
         self.max_tokens = max_tokens
         self.reasoning_effort = reasoning_effort
+        self.extra_body = extra_body
         self.image_format = "url"
         print(
-            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
+            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=} {self.extra_body=}"
         )
 
     def _handle_image(
@@ -136,7 +138,7 @@ class ChatCompletionSampler(SamplerBase):
                 self._pack_message("system", self.system_message)
             ] + message_list
         trial = 0
-        while True:
+        while trial < 6:  # 126 seconds in total
            try:
                 response = self.client.chat.completions.create(
                     model=self.model,
@@ -144,6 +146,7 @@ class ChatCompletionSampler(SamplerBase):
                     temperature=self.temperature,
                     max_tokens=self.max_tokens,
                     reasoning_effort=self.reasoning_effort,
+                    extra_body=self.extra_body,
                 )
                 return response.choices[0].message.content
             # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
sglang/test/simple_eval_longbench_v2.py
ADDED
# Adapted from https://github.com/openai/simple-evals/

"""
LongBench v2: Towards Deeper Understanding and Reasoning on Realistic Long-Context Multitasks
Yushi Bai, Shangqing Tu, Jiajie Zhang, Hao Peng, Xiaozhi Wang, Xin Lv, Shulin Cao, Jiazheng Xu, Lei Hou, Yuxiao Dong, Jie Tang, Juanzi Li
https://arxiv.org/abs/2412.15204
"""

import csv
import json
import os
import re
from typing import Any, Dict, List, Optional

from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
    ANSWER_PATTERN_MULTICHOICE,
    HTML_JINJA,
    Eval,
    EvalResult,
    SamplerBase,
    SingleEvalResult,
)

# LongBench-v2 task categories
TASK_CATEGORIES = {
    "single_document_qa",
    "multi_document_qa",
    "long_in_context_learning",
    "long_dialogue_history",
    "code_repo_understanding",
    "long_structured_data",
}

DEFAULT_DATASET = "THUDM/LongBench-v2"
DEFAULT_DATASET_SPLIT = "train"


def format_longbench_v2_question(row: dict) -> str:
    """Format a LongBench-v2 question using the official template."""
    context = row.get("context", "")
    question = row.get("question", "")

    # Handle both standard format (A, B, C, D) and alternative format (choices list)
    if "choices" in row:
        choices = row["choices"]
        choice_A = choices[0] if len(choices) > 0 else ""
        choice_B = choices[1] if len(choices) > 1 else ""
        choice_C = choices[2] if len(choices) > 2 else ""
        choice_D = choices[3] if len(choices) > 3 else ""
    else:
        choice_A = row.get("A", row.get("choice_A", ""))
        choice_B = row.get("B", row.get("choice_B", ""))
        choice_C = row.get("C", row.get("choice_C", ""))
        choice_D = row.get("D", row.get("choice_D", ""))

    # Official LongBench-v2 template
    prompt = f"""{context.strip()}

What is the correct answer to this question: {question.strip()}
Choices:
(A) {choice_A.strip()}
(B) {choice_B.strip()}
(C) {choice_C.strip()}
(D) {choice_D.strip()}

The correct answer is"""

    return prompt


def extract_longbench_v2_answer(response: str) -> Optional[str]:
    """Extract answer from model response using official LongBench-v2 method."""
    response = response.replace("*", "")

    # First try: "The correct answer is (A)"
    match = re.search(r"The correct answer is \(([A-D])\)", response, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Second try: "The correct answer is A"
    match = re.search(r"The correct answer is ([A-D])", response, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Fallback: Standard SGLang multichoice pattern
    match = re.search(ANSWER_PATTERN_MULTICHOICE, response)
    if match:
        return match.group(1).upper()

    # Generic fallback when model says "answer is A"
    match = re.search(r"answer\s+is\s*\(?([A-D])\)?", response, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    return None


class LongBenchV2Eval(Eval):
    """
    Evaluation utility for LongBench-v2 dataset.

    LongBench-v2 is designed to assess the ability of LLMs to handle long-context problems
    requiring deep understanding and reasoning across real-world multitasks.
    """

    def __init__(
        self,
        data_source: str = DEFAULT_DATASET,
        num_examples: Optional[int] = None,
        num_threads: int = 1,
        n_repeats: int = 1,
        categories: Optional[List[str]] = None,
        max_context_length: Optional[int] = None,
        min_context_length: Optional[int] = None,
    ):
        """
        Initialize LongBench-v2 evaluation.

        Args:
            data_source: HuggingFace dataset name, local file path (CSV/JSON)
            num_examples: Number of examples to evaluate (None for all)
            num_threads: Number of threads for parallel processing
            n_repeats: Number of times to repeat evaluation for error bars
            categories: List of task categories to include (None for all)
            max_context_length: Maximum context length in characters
            min_context_length: Minimum context length in characters
        """
        # Load dataset based on data source type
        examples = self._load_dataset(data_source)

        # Apply filtering
        if categories:
            examples = [ex for ex in examples if ex.get("category") in categories]

        if min_context_length or max_context_length:
            examples = self._filter_by_context_length(
                examples, min_context_length, max_context_length
            )

        # Sample examples if specified
        if num_examples:
            assert n_repeats == 1, "n_repeats only supported when not sampling examples"
            examples = examples[: min(num_examples, len(examples))]

        # Repeat examples for multiple runs
        examples = examples * n_repeats

        if not examples:
            raise ValueError(
                "No examples available for LongBench-v2 evaluation after filtering"
            )

        self.examples = examples
        self.n_repeats = n_repeats
        self.num_threads = num_threads

        print(f"Loaded {len(self.examples)} examples from LongBench-v2")
        if categories:
            print(f"Filtered to categories: {categories}")
        if min_context_length or max_context_length:
            print(
                f"Context length filter: {min_context_length}-{max_context_length} characters"
            )

    def _load_dataset(self, data_source: str) -> List[Dict[str, Any]]:
        """Load dataset from HuggingFace hub or local files."""

        if not data_source:
            data_source = DEFAULT_DATASET

        if os.path.exists(data_source):
            raw_examples = self._load_local_file(data_source)
        else:
            raw_examples = self._load_hf_dataset(data_source)

        return [self._normalize_example(example) for example in raw_examples]

    def _load_local_file(self, path: str) -> List[Dict[str, Any]]:
        """Load examples from a local CSV/JSON/JSONL file."""

        suffix = os.path.splitext(path)[1].lower()
        if suffix in {".json", ".jsonl"}:
            with open(path, "r", encoding="utf-8") as fh:
                if suffix == ".jsonl":
                    data = [json.loads(line) for line in fh if line.strip()]
                else:
                    data = json.load(fh)
        elif suffix == ".csv":
            with open(path, "r", encoding="utf-8") as fh:
                reader = csv.DictReader(fh)
                data = list(reader)
        else:
            # Try JSON, then CSV as fallback
            try:
                with open(path, "r", encoding="utf-8") as fh:
                    data = json.load(fh)
            except json.JSONDecodeError:
                with open(path, "r", encoding="utf-8") as fh:
                    reader = csv.DictReader(fh)
                    data = list(reader)

        if isinstance(data, dict):
            data = data.get("data", [])

        if not isinstance(data, list):
            raise ValueError("Expected list of examples from local file")

        return data

    def _load_hf_dataset(self, identifier: str) -> List[Dict[str, Any]]:
        """Load the dataset from HuggingFace Hub."""

        parts = identifier.split(":", maxsplit=1)
        dataset_name = parts[0]
        split = parts[1] if len(parts) == 2 else DEFAULT_DATASET_SPLIT

        try:
            from datasets import load_dataset  # type: ignore
        except ImportError as exc:
            raise ImportError(
                "Please install the 'datasets' package to load LongBench-v2 from HuggingFace: pip install datasets"
            ) from exc

        dataset = load_dataset(dataset_name, split=split)
        return [dict(row) for row in dataset]

    def _normalize_example(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure each example exposes the expected keys."""

        normalized = dict(example)

        for letter in ["A", "B", "C", "D"]:
            choice_key = f"choice_{letter}"
            if letter not in normalized and choice_key in normalized:
                normalized[letter] = normalized[choice_key]

        if "category" not in normalized and "domain" in normalized:
            normalized["category"] = normalized["domain"]

        answer = normalized.get("answer")
        if isinstance(answer, str):
            normalized["answer"] = answer.strip().upper()
        elif isinstance(answer, int) and 0 <= answer < 4:
            normalized["answer"] = ["A", "B", "C", "D"][answer]

        return normalized

    def _filter_by_context_length(
        self,
        examples: List[Dict[str, Any]],
        min_length: Optional[int],
        max_length: Optional[int],
    ) -> List[Dict[str, Any]]:
        """Filter examples by context length measured in characters."""
        filtered = []
        for example in examples:
            context = example.get("context", "")
            context_length = len(context)

            if min_length is not None and context_length < min_length:
                continue
            if max_length is not None and context_length > max_length:
                continue

            filtered.append(example)

        return filtered

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Run the evaluation."""

        def fn(row: dict):
            # Format the question using official template
            formatted_question = format_longbench_v2_question(row)

            prompt_messages = [
                sampler._pack_message(content=formatted_question, role="user")
            ]

            # Get model response
            response_text = sampler(prompt_messages)
            if response_text is None:
                response_text = ""

            # Extract answer using official method
            extracted_answer = extract_longbench_v2_answer(response_text)

            # Get correct answer
            correct_answer = row.get("answer", "")
            if isinstance(correct_answer, str):
                correct_answer = correct_answer.strip().upper()
            elif isinstance(correct_answer, int) and 0 <= correct_answer < 4:
                correct_answer = ["A", "B", "C", "D"][correct_answer]

            # Calculate score
            score = 1.0 if extracted_answer == correct_answer else 0.0

            # Generate HTML report
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=correct_answer,
                extracted_answer=extracted_answer,
            )

            # Build conversation
            convo = prompt_messages + [dict(content=response_text, role="assistant")]

            # Prepare metrics
            metrics = {"chars": len(response_text)}

            # Add category-specific metrics
            category = row.get("category", row.get("domain", "unknown"))
            if category in TASK_CATEGORIES:
                metrics[category] = score

            difficulty = row.get("difficulty")
            if isinstance(difficulty, str) and difficulty:
                metrics[f"difficulty_{difficulty.lower()}"] = score

            return SingleEvalResult(
                html=html,
                score=score,
                convo=convo,
                metrics=metrics,
            )

        # Run evaluation with progress tracking
        results = common.map_with_progress(fn, self.examples, self.num_threads)
        return common.aggregate_results(results)
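
A quick offline sanity check of the two helpers in the new module; the row below is toy data invented for illustration, and no server or dataset download is required:

from sglang.test.simple_eval_longbench_v2 import (
    extract_longbench_v2_answer,
    format_longbench_v2_question,
)

# Toy row using the standard A/B/C/D keys handled by the formatter.
row = {
    "context": "The package was renamed from foo to bar in version 2.0.",
    "question": "What is the package called in version 2.0?",
    "A": "foo",
    "B": "bar",
    "C": "baz",
    "D": "qux",
    "answer": "B",
}

prompt = format_longbench_v2_question(row)
assert prompt.endswith("The correct answer is")

# The extractor tolerates several phrasings and strips '*' markers first.
assert extract_longbench_v2_answer("The correct answer is (B)") == "B"
assert extract_longbench_v2_answer("**The correct answer is B**") == "B"
assert extract_longbench_v2_answer("I think the answer is (b).") == "B"
assert extract_longbench_v2_answer("no idea") is None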