sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/weight_sync/utils.py
CHANGED
@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
|
|
6
6
|
from torch.distributed.tensor import DTensor
|
7
7
|
|
8
8
|
from sglang.srt.entrypoints.engine import Engine
|
9
|
-
from sglang.srt.managers.
|
9
|
+
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
|
10
10
|
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
11
11
|
from sglang.srt.utils import MultiprocessingSerializer
|
12
12
|
|
@@ -33,7 +33,7 @@ async def update_weights(
|
|
33
33
|
"""
|
34
34
|
infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
|
35
35
|
infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
|
36
|
-
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
36
|
+
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
|
37
37
|
|
38
38
|
monkey_patch_torch_reductions()
|
39
39
|
|
sglang/test/attention/test_trtllm_mla_backend.py
CHANGED
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
|
|
41
41
|
"v_head_dim": 512,
|
42
42
|
"num_kv_heads": 1,
|
43
43
|
"layer_id": 0,
|
44
|
+
"tp_q_head_num": 128,
|
45
|
+
"tp_k_head_num": 128,
|
46
|
+
"prefill_head_dim": 192,
|
47
|
+
"prefill_v_head_dim": 128,
|
44
48
|
}
|
45
49
|
|
46
50
|
ROPE_BASE = 10000
|
@@ -92,7 +96,7 @@ TEST_CASES = {
|
|
92
96
|
"description": "Medium-scale batch",
|
93
97
|
},
|
94
98
|
],
|
95
|
-
"
|
99
|
+
"output_match": [
|
96
100
|
{
|
97
101
|
"name": "single_fp16",
|
98
102
|
"batch_size": 1,
|
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
322
326
|
config.update(test_case)
|
323
327
|
return config
|
324
328
|
|
325
|
-
def _create_model_components(self, config):
|
329
|
+
def _create_model_components(self, config, is_prefill=False):
|
326
330
|
"""Create model runners, backends, and layer for testing."""
|
327
331
|
# Create model runners
|
328
332
|
model_runner_trtllm = MockModelRunner(config)
|
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
332
336
|
trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
|
333
337
|
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
|
334
338
|
|
339
|
+
head_dim = (
|
340
|
+
config["kv_lora_rank"] + config["qk_rope_head_dim"]
|
341
|
+
if not is_prefill
|
342
|
+
else config["prefill_head_dim"]
|
343
|
+
)
|
344
|
+
v_head_dim = (
|
345
|
+
config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
|
346
|
+
)
|
347
|
+
|
335
348
|
# Create RadixAttention layer
|
336
349
|
layer = RadixAttention(
|
337
350
|
num_heads=config["num_attention_heads"],
|
338
|
-
head_dim=
|
351
|
+
head_dim=head_dim,
|
339
352
|
scaling=model_runner_trtllm.model_config.scaling,
|
340
353
|
num_kv_heads=config["num_kv_heads"],
|
341
354
|
layer_id=config["layer_id"],
|
342
|
-
v_head_dim=
|
355
|
+
v_head_dim=v_head_dim,
|
343
356
|
prefix="attn_mqa",
|
344
357
|
)
|
345
358
|
|
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
524
537
|
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
|
525
538
|
print(f"\nRunning decode output matching tests...")
|
526
539
|
|
527
|
-
for test_case in TEST_CASES["
|
540
|
+
for test_case in TEST_CASES["output_match"]:
|
528
541
|
with self.subTest(test_case=test_case["name"]):
|
529
542
|
print(f" Testing {test_case['name']}: {test_case['description']}")
|
530
543
|
|
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
|
|
1099
1112
|
self.assertIsNotNone(metadata_3.block_kv_indices)
|
1100
1113
|
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
|
1101
1114
|
|
1115
|
+
def test_prefill_output_match_self_attention(self):
|
1116
|
+
"""Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
|
1117
|
+
print(f"\nRunning prefill output tests...")
|
1118
|
+
|
1119
|
+
for test_case in TEST_CASES["output_match"][:2]: # Just a subset for speed
|
1120
|
+
with self.subTest(test_case=test_case["name"]):
|
1121
|
+
print(
|
1122
|
+
f"Prefill Testing {test_case['name']}: {test_case['description']}"
|
1123
|
+
)
|
1124
|
+
|
1125
|
+
config = self._merge_config(test_case)
|
1126
|
+
batch_size = config["batch_size"]
|
1127
|
+
max_seq_len = config["max_seq_len"]
|
1128
|
+
|
1129
|
+
# Create components
|
1130
|
+
(
|
1131
|
+
model_runner_trtllm,
|
1132
|
+
model_runner_reference,
|
1133
|
+
trtllm_backend,
|
1134
|
+
reference_backend,
|
1135
|
+
layer,
|
1136
|
+
) = self._create_model_components(config, is_prefill=True)
|
1137
|
+
|
1138
|
+
# Prefill uses full sequences
|
1139
|
+
seq_lens = torch.full(
|
1140
|
+
(batch_size,), max_seq_len, device=config["device"]
|
1141
|
+
)
|
1142
|
+
|
1143
|
+
def _create_forward_batch_prefill(
|
1144
|
+
batch_size,
|
1145
|
+
seq_lens,
|
1146
|
+
extend_prefix_lens,
|
1147
|
+
backend,
|
1148
|
+
model_runner,
|
1149
|
+
config,
|
1150
|
+
):
|
1151
|
+
"""Create a forward batch for the given backend."""
|
1152
|
+
|
1153
|
+
fb = ForwardBatch(
|
1154
|
+
batch_size=batch_size,
|
1155
|
+
input_ids=torch.randint(
|
1156
|
+
0, 100, (batch_size, 1), device=config["device"]
|
1157
|
+
),
|
1158
|
+
out_cache_loc=torch.arange(batch_size, device=config["device"]),
|
1159
|
+
seq_lens_sum=int(seq_lens.sum().item()),
|
1160
|
+
extend_prefix_lens=extend_prefix_lens,
|
1161
|
+
extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
|
1162
|
+
extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
|
1163
|
+
.cpu()
|
1164
|
+
.int()
|
1165
|
+
.tolist(),
|
1166
|
+
forward_mode=ForwardMode.EXTEND,
|
1167
|
+
req_pool_indices=torch.arange(
|
1168
|
+
batch_size, device=config["device"]
|
1169
|
+
),
|
1170
|
+
seq_lens=seq_lens,
|
1171
|
+
seq_lens_cpu=seq_lens.cpu(),
|
1172
|
+
attn_attend_prefix_cache=False,
|
1173
|
+
mha_return_lse=False,
|
1174
|
+
attn_backend=backend,
|
1175
|
+
)
|
1176
|
+
fb.req_to_token_pool = model_runner.req_to_token_pool
|
1177
|
+
fb.token_to_kv_pool = model_runner.token_to_kv_pool
|
1178
|
+
|
1179
|
+
# Add position information for RoPE
|
1180
|
+
fb.positions = torch.arange(batch_size, device=config["device"])
|
1181
|
+
|
1182
|
+
return fb
|
1183
|
+
|
1184
|
+
# Create forward batches
|
1185
|
+
fb_trtllm = _create_forward_batch_prefill(
|
1186
|
+
batch_size,
|
1187
|
+
seq_lens.clone(),
|
1188
|
+
torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
|
1189
|
+
trtllm_backend,
|
1190
|
+
model_runner_trtllm,
|
1191
|
+
config,
|
1192
|
+
)
|
1193
|
+
fb_reference = _create_forward_batch_prefill(
|
1194
|
+
batch_size,
|
1195
|
+
seq_lens.clone(),
|
1196
|
+
torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
|
1197
|
+
reference_backend,
|
1198
|
+
model_runner_reference,
|
1199
|
+
config,
|
1200
|
+
)
|
1201
|
+
|
1202
|
+
# Initialize metadata for both backends
|
1203
|
+
trtllm_backend.init_forward_metadata(fb_trtllm)
|
1204
|
+
reference_backend.init_forward_metadata(fb_reference)
|
1205
|
+
|
1206
|
+
# Create Q, K, V tensors for prefill
|
1207
|
+
torch.manual_seed(config["seed_qkv"])
|
1208
|
+
|
1209
|
+
def _create_qkv_tensors_prefill(
|
1210
|
+
batch_size, seq_len, config, dtype_override=None
|
1211
|
+
):
|
1212
|
+
"""Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
|
1213
|
+
device = config["device"]
|
1214
|
+
dtype = dtype_override or config["dtype"]
|
1215
|
+
|
1216
|
+
total_tokens = batch_size * seq_len
|
1217
|
+
|
1218
|
+
tp_q_head_num = config["tp_q_head_num"]
|
1219
|
+
tp_k_head_num = config["tp_k_head_num"]
|
1220
|
+
head_dim = config["prefill_head_dim"]
|
1221
|
+
v_head_dim = config["prefill_v_head_dim"]
|
1222
|
+
|
1223
|
+
q = torch.randn(
|
1224
|
+
(total_tokens, tp_q_head_num * head_dim),
|
1225
|
+
dtype=dtype,
|
1226
|
+
device=device,
|
1227
|
+
)
|
1228
|
+
k = torch.randn(
|
1229
|
+
(total_tokens, tp_k_head_num * head_dim),
|
1230
|
+
dtype=dtype,
|
1231
|
+
device=device,
|
1232
|
+
)
|
1233
|
+
v = torch.randn(
|
1234
|
+
(total_tokens, tp_k_head_num * v_head_dim),
|
1235
|
+
dtype=dtype,
|
1236
|
+
device=device,
|
1237
|
+
)
|
1238
|
+
|
1239
|
+
# Reshape as requested
|
1240
|
+
q = q.view(-1, tp_q_head_num, head_dim)
|
1241
|
+
k = k.view(-1, tp_k_head_num, head_dim)
|
1242
|
+
v = v.view(-1, tp_k_head_num, v_head_dim)
|
1243
|
+
|
1244
|
+
return q, k, v
|
1245
|
+
|
1246
|
+
q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
|
1247
|
+
# Run prefill on both backends
|
1248
|
+
out_trtllm = trtllm_backend.forward_extend(
|
1249
|
+
q, k, v, layer, fb_trtllm, False
|
1250
|
+
).view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
1251
|
+
out_reference = reference_backend.forward_extend(
|
1252
|
+
q, k, v, layer, fb_reference, False
|
1253
|
+
)
|
1254
|
+
|
1255
|
+
tolerance = config.get("tolerance", 1e-2)
|
1256
|
+
comparison_passed = compare_outputs(
|
1257
|
+
out_trtllm, out_reference, tolerance=tolerance
|
1258
|
+
)
|
1259
|
+
self.assertTrue(
|
1260
|
+
comparison_passed,
|
1261
|
+
f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
|
1262
|
+
f"Config: {test_case['name']}, "
|
1263
|
+
f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
|
1264
|
+
)
|
1265
|
+
|
1102
1266
|
|
1103
1267
|
if __name__ == "__main__":
|
1104
1268
|
unittest.main()
|
sglang/test/get_logits_ut.py
ADDED
@@ -0,0 +1,57 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
|
4
|
+
|
5
|
+
class DummyModel(nn.Module):
|
6
|
+
def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
|
7
|
+
super().__init__()
|
8
|
+
self.weights_proj = nn.Linear(d_in, 1024)
|
9
|
+
self.n_heads = n_heads
|
10
|
+
self.softmax_scale = softmax_scale
|
11
|
+
|
12
|
+
def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
|
13
|
+
weights = self.weights_proj(x)
|
14
|
+
weights = weights * self.n_heads**-0.5
|
15
|
+
q_scale = q_scale.unsqueeze(1) # (B,1,1)
|
16
|
+
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
|
17
|
+
return weights
|
18
|
+
|
19
|
+
def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
|
20
|
+
weights = self.weights_proj(x)
|
21
|
+
q_scale = q_scale.unsqueeze(1) # (B,1,1)
|
22
|
+
scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale # (B,1,1)
|
23
|
+
weights = weights.unsqueeze(-1) * scale_const # (B,1024,1)
|
24
|
+
return weights
|
25
|
+
|
26
|
+
|
27
|
+
def main():
|
28
|
+
torch.manual_seed(0)
|
29
|
+
model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
|
30
|
+
x = torch.randn(128, 2048) # batch=128, d_in=2048
|
31
|
+
q_scale = torch.randn(128, 1)
|
32
|
+
|
33
|
+
import time
|
34
|
+
|
35
|
+
start = time.time()
|
36
|
+
for _ in range(1000):
|
37
|
+
out_orig = model._get_logits_head_gate_orig(x, q_scale)
|
38
|
+
print("Original version time:", time.time() - start)
|
39
|
+
|
40
|
+
start = time.time()
|
41
|
+
for _ in range(1000):
|
42
|
+
out_opt = model._get_logits_head_gate_opt(x, q_scale)
|
43
|
+
print("Optimized version time:", time.time() - start)
|
44
|
+
|
45
|
+
print("Difference:", (out_orig - out_opt).abs().max().item())
|
46
|
+
assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
|
47
|
+
|
48
|
+
|
49
|
+
if __name__ == "__main__":
|
50
|
+
main()
|
51
|
+
|
52
|
+
|
53
|
+
"""
|
54
|
+
Original version time: 0.49235057830810547
|
55
|
+
Optimized version time: 0.4087331295013428
|
56
|
+
Difference: 1.4901161193847656e-08
|
57
|
+
"""
|
sglang/test/run_eval.py
CHANGED
@@ -10,11 +10,46 @@ import time
|
|
10
10
|
|
11
11
|
from sglang.test.simple_eval_common import (
|
12
12
|
ChatCompletionSampler,
|
13
|
+
Eval,
|
13
14
|
make_report,
|
14
15
|
set_ulimit,
|
15
16
|
)
|
16
17
|
|
17
18
|
|
19
|
+
def get_thinking_kwargs(args):
    """Build extra request-body kwargs that turn on the model's thinking mode.

    Reads ``args.thinking_mode`` (missing is treated as ``None``). When it is
    one of ``THINKING_MODE_CHOICES`` the result carries a
    ``chat_template_kwargs`` dict with the matching flag set to ``True``;
    otherwise an empty dict is returned.
    """
    thinking_mode = getattr(args, "thinking_mode", None)
    if thinking_mode not in THINKING_MODE_CHOICES:
        return {}
    # The "deepseek-v3" mode uses the ``thinking`` key; every other supported
    # mode uses ``enable_thinking``.
    flag_name = "thinking" if thinking_mode == "deepseek-v3" else "enable_thinking"
    return {"chat_template_kwargs": {flag_name: True}}
|
30
|
+
|
31
|
+
|
32
|
+
def run_eval_once(args, base_url: str, eval_obj: Eval) -> tuple:
    """Run ``eval_obj`` once against the OpenAI-compatible server at ``base_url``.

    A fresh ``ChatCompletionSampler`` is built from the CLI ``args`` (falling
    back to max_tokens=2048, temperature=0.0 and reasoning_effort=None when
    those attributes are absent), with any thinking-mode request kwargs
    passed through as ``extra_body``.

    Returns:
        ``(result, latency, sampler)``: the object returned by
        ``eval_obj(sampler)``, the wall-clock seconds spent in that call, and
        the sampler that was used. (The previous ``-> dict`` annotation was
        wrong; callers unpack this 3-tuple.)
    """
    # Get thinking kwargs based on user's choice
    thinking_kwargs = get_thinking_kwargs(args)

    sampler = ChatCompletionSampler(
        model=args.model,
        max_tokens=getattr(args, "max_tokens", 2048),
        base_url=base_url,
        temperature=getattr(args, "temperature", 0.0),
        reasoning_effort=getattr(args, "reasoning_effort", None),
        extra_body=thinking_kwargs,
    )

    # Time only the evaluation itself, not sampler construction.
    tic = time.perf_counter()
    result = eval_obj(sampler)
    latency = time.perf_counter() - tic

    return result, latency, sampler
|
51
|
+
|
52
|
+
|
18
53
|
def run_eval(args):
|
19
54
|
set_ulimit()
|
20
55
|
|
@@ -60,21 +95,40 @@ def run_eval(args):
|
|
60
95
|
from sglang.test.simple_eval_humaneval import HumanEval
|
61
96
|
|
62
97
|
eval_obj = HumanEval(args.num_examples, args.num_threads)
|
98
|
+
elif args.eval_name == "mmmu":
|
99
|
+
# VLM MMMU evaluation with fixed 100 examples by default
|
100
|
+
from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval
|
101
|
+
|
102
|
+
eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
|
63
103
|
else:
|
64
104
|
raise ValueError(f"Invalid eval name: {args.eval_name}")
|
65
105
|
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
temperature=getattr(args, "temperature", 0.0),
|
71
|
-
reasoning_effort=getattr(args, "reasoning_effort", None),
|
72
|
-
)
|
106
|
+
if getattr(args, "repeat", 1) == 1:
|
107
|
+
result, latency, sampler = run_eval_once(args, base_url, eval_obj)
|
108
|
+
else:
|
109
|
+
from concurrent.futures import ThreadPoolExecutor
|
73
110
|
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
111
|
+
executor = ThreadPoolExecutor(max_workers=args.repeat)
|
112
|
+
|
113
|
+
futures = [
|
114
|
+
executor.submit(run_eval_once, args, base_url, eval_obj)
|
115
|
+
for _ in range(args.repeat)
|
116
|
+
]
|
117
|
+
|
118
|
+
scores_repeat = []
|
119
|
+
|
120
|
+
for f in futures:
|
121
|
+
result, latency, sampler = f.result()
|
122
|
+
scores_repeat.append(result.score)
|
123
|
+
|
124
|
+
mean_score = sum(scores_repeat) / len(scores_repeat)
|
125
|
+
scores_repeat = [f"{s:.3f}" for s in scores_repeat]
|
126
|
+
print("=" * 20)
|
127
|
+
print(f"Repeat: {args.repeat}, mean: {mean_score:.3f}")
|
128
|
+
print(f"Scores: {scores_repeat}")
|
129
|
+
print("=" * 20)
|
130
|
+
|
131
|
+
executor.shutdown()
|
78
132
|
|
79
133
|
# Dump reports
|
80
134
|
metrics = result.metrics | {"score": result.score}
|
@@ -94,9 +148,13 @@ def run_eval(args):
|
|
94
148
|
print(f"Total latency: {latency:.3f} s")
|
95
149
|
print(f"Score: {metrics['score']:.3f}")
|
96
150
|
|
151
|
+
if getattr(args, "return_latency", False):
|
152
|
+
return metrics, latency
|
97
153
|
return metrics
|
98
154
|
|
99
155
|
|
156
|
+
# Values accepted by --thinking-mode; get_thinking_kwargs() turns a matching
# value into a chat_template_kwargs flag ("thinking" for deepseek-v3,
# "enable_thinking" otherwise).
THINKING_MODE_CHOICES = ["deepseek-r1", "deepseek-v3", "qwen3"]
|
157
|
+
|
100
158
|
if __name__ == "__main__":
|
101
159
|
parser = argparse.ArgumentParser()
|
102
160
|
parser.add_argument(
|
@@ -118,12 +176,22 @@ if __name__ == "__main__":
|
|
118
176
|
type=str,
|
119
177
|
help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
|
120
178
|
)
|
179
|
+
parser.add_argument(
|
180
|
+
"--repeat", type=int, default=1, help="repeat the evaluation n times"
|
181
|
+
)
|
121
182
|
parser.add_argument("--eval-name", type=str, default="mmlu")
|
122
183
|
parser.add_argument("--num-examples", type=int)
|
123
184
|
parser.add_argument("--num-threads", type=int, default=512)
|
124
185
|
parser.add_argument("--max-tokens", type=int, default=2048)
|
125
186
|
parser.add_argument("--temperature", type=float, default=0.0)
|
126
187
|
parser.add_argument("--reasoning-effort", type=str)
|
188
|
+
parser.add_argument(
|
189
|
+
"--thinking-mode",
|
190
|
+
default=None,
|
191
|
+
type=str,
|
192
|
+
choices=THINKING_MODE_CHOICES,
|
193
|
+
help="Enable thinking mode in Deepseek R1, V3.1/3.2, or Qwen3",
|
194
|
+
)
|
127
195
|
args = parser.parse_args()
|
128
196
|
|
129
197
|
run_eval(args)
|
sglang/test/runners.py
CHANGED
@@ -30,8 +30,8 @@ from transformers import (
|
|
30
30
|
)
|
31
31
|
|
32
32
|
from sglang.srt.entrypoints.engine import Engine
|
33
|
-
from sglang.srt.hf_transformers_utils import get_tokenizer
|
34
33
|
from sglang.srt.utils import load_image
|
34
|
+
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
35
35
|
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
|
36
36
|
|
37
37
|
DEFAULT_PROMPTS = [
|
@@ -505,6 +505,7 @@ class SRTRunner:
|
|
505
505
|
mem_fraction_static: float = 0.65,
|
506
506
|
trust_remote_code: bool = False,
|
507
507
|
speculative_draft_model_path: Optional[str] = None,
|
508
|
+
speculative_draft_model_revision: Optional[str] = None,
|
508
509
|
speculative_algorithm: Optional[str] = None,
|
509
510
|
speculative_num_steps: Optional[int] = None,
|
510
511
|
speculative_eagle_topk: Optional[int] = None,
|
@@ -526,6 +527,9 @@ class SRTRunner:
|
|
526
527
|
spec_kwargs = {}
|
527
528
|
if speculative_draft_model_path:
|
528
529
|
spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
|
530
|
+
spec_kwargs["speculative_draft_model_revision"] = (
|
531
|
+
speculative_draft_model_revision
|
532
|
+
)
|
529
533
|
spec_kwargs["speculative_algorithm"] = speculative_algorithm
|
530
534
|
spec_kwargs["speculative_num_steps"] = speculative_num_steps
|
531
535
|
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
|
@@ -93,6 +93,7 @@ class ChatCompletionSampler(SamplerBase):
|
|
93
93
|
temperature: float = 0.0,
|
94
94
|
reasoning_effort: Optional[str] = None,
|
95
95
|
max_tokens: int = 2048,
|
96
|
+
extra_body: Optional[Dict[str, Any]] = None,
|
96
97
|
):
|
97
98
|
self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
|
98
99
|
|
@@ -104,9 +105,10 @@ class ChatCompletionSampler(SamplerBase):
|
|
104
105
|
self.temperature = temperature
|
105
106
|
self.max_tokens = max_tokens
|
106
107
|
self.reasoning_effort = reasoning_effort
|
108
|
+
self.extra_body = extra_body
|
107
109
|
self.image_format = "url"
|
108
110
|
print(
|
109
|
-
f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
|
111
|
+
f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=} {self.extra_body=}"
|
110
112
|
)
|
111
113
|
|
112
114
|
def _handle_image(
|
@@ -136,7 +138,7 @@ class ChatCompletionSampler(SamplerBase):
|
|
136
138
|
self._pack_message("system", self.system_message)
|
137
139
|
] + message_list
|
138
140
|
trial = 0
|
139
|
-
while
|
141
|
+
while trial < 6: # 126 seconds in total
|
140
142
|
try:
|
141
143
|
response = self.client.chat.completions.create(
|
142
144
|
model=self.model,
|
@@ -144,6 +146,7 @@ class ChatCompletionSampler(SamplerBase):
|
|
144
146
|
temperature=self.temperature,
|
145
147
|
max_tokens=self.max_tokens,
|
146
148
|
reasoning_effort=self.reasoning_effort,
|
149
|
+
extra_body=self.extra_body,
|
147
150
|
)
|
148
151
|
return response.choices[0].message.content
|
149
152
|
# NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
|