sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
@@ -16,7 +16,6 @@
 import asyncio
 import copy
 import dataclasses
-import json
 import logging
 import math
 import os
@@ -25,7 +24,6 @@ import signal
 import sys
 import threading
 import time
-import uuid
 from collections import deque
 from contextlib import nullcontext
 from datetime import datetime
@@ -34,13 +32,13 @@ from http import HTTPStatus
 from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
+import orjson
 import torch
 import uvloop
 import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks

-from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.lora.lora_registry import LoRARegistry
@@ -48,6 +46,7 @@ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchT
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BaseReq,
     BatchEmbeddingOutput,
     BatchMultimodalOutput,
     BatchStrOutput,
@@ -60,7 +59,6 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     GetLoadReqInput,
     HealthCheckOutput,
-    MultiTokenizerWrapper,
     OpenSessionReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
@@ -90,10 +88,10 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
-    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.srt.utils.aio_rwlock import RWLock
 from sglang.srt.utils.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
@@ -157,7 +155,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.log_requests = server_args.log_requests
         self.log_requests_level = server_args.log_requests_level
         self.preferred_sampling_params = (
-
+            orjson.loads(server_args.preferred_sampling_params)
             if server_args.preferred_sampling_params
             else None
         )
@@ -182,6 +180,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if speculative_algorithm.is_none()
             else server_args.speculative_num_draft_tokens
         )
+        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
+        self.multi_item_delimiter_text = None

         if self.model_config.is_multimodal:
             import_processors("sglang.srt.multimodal.processors")
@@ -223,6 +223,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             self.processor = _processor
             self.tokenizer = get_tokenizer_from_processor(self.processor)
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
+            self._initialize_multi_item_delimiter_text()
         else:
             self.mm_processor = self.processor = None

@@ -235,6 +236,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
+            self._initialize_multi_item_delimiter_text()
         # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
         if (
             server_args.enable_dynamic_batch_tokenizer
@@ -255,16 +257,25 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         )
         if self.server_args.tokenizer_worker_num > 1:
             # Use tokenizer_worker_ipc_name in multi-tokenizer mode
-
+            send_to_scheduler = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
             )
+
+            class SenderWrapper:
+                def send_pyobj(self, obj):
+                    if isinstance(obj, BaseReq):
+                        obj.http_worker_ipc = port_args.tokenizer_ipc_name
+                    send_to_scheduler.send_pyobj(obj)
+
+            # Make sure that each request carries the tokenizer_ipc_name for response routing
+            self.send_to_scheduler = SenderWrapper()
         else:
             self.send_to_scheduler = get_zmq_socket(
                 context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
             )

         # Request states
-        self.
+        self._chosen_loop = None
         self.rid_to_state: Dict[str, ReqState] = {}
         self.asyncio_tasks = set()

@@ -273,6 +284,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.gracefully_exit = False
         self.last_receive_tstamp = 0

+        # Initial weights status
+        self.initial_weights_loaded = True
+        if server_args.checkpoint_engine_wait_weights_before_ready:
+            self.initial_weights_loaded = False
+
         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
@@ -355,7 +371,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 (
                     FreezeGCReq,
                     lambda x: None,
-                ),
+                ),
+                # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
                 (HealthCheckOutput, lambda x: None),
             ]
         )
@@ -372,13 +389,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         obj.normalize_batch_and_arguments()

         if self.server_args.tokenizer_worker_num > 1:
-
-
-
-
-            else:
-                # If it's a single value, add worker_id prefix
-                obj.rid = f"{self.worker_id}_{obj.rid}"
+            from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker
+
+            assert isinstance(self, TokenizerWorker)
+            self._attach_multi_http_worker_info(obj)

         if self.enable_trace:
             self._trace_request_start(obj, created_time)
@@ -582,9 +596,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         )

         if self.mm_processor and obj.contains_mm_input():
-            if not isinstance(obj.image_data, list):
+            if obj.image_data is not None and not isinstance(obj.image_data, list):
                 obj.image_data = [obj.image_data]
-            if not isinstance(obj.audio_data, list):
+            if obj.audio_data is not None and not isinstance(obj.audio_data, list):
                 obj.audio_data = [obj.audio_data]
             mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
                 image_data=obj.image_data,
@@ -724,6 +738,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             obj.token_ids_logprob,
             obj.stream,
             rid=obj.rid,
+            http_worker_ipc=obj.http_worker_ipc,
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
@@ -745,6 +760,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             sampling_params,
             rid=obj.rid,
             priority=obj.priority,
+            http_worker_ipc=obj.http_worker_ipc,
         )

         return tokenized_obj
@@ -755,6 +771,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         """Handle batch tokenization for text inputs only."""
         logger.debug(f"Starting batch tokenization for {batch_size} text requests")

+        # If batch does not have text nothing to tokenize
+        # so lets construct the return object
+        if not self._batch_has_text(batch_size, obj):
+            # All requests already have input_ids, no need to tokenize
+            return [await self._tokenize_one_request(obj[i]) for i in range(batch_size)]
+
+        self._validate_batch_tokenization_constraints(batch_size, obj)
+
         # Collect requests and texts
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]
@@ -804,6 +828,30 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )

+    def _batch_has_text(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> bool:
+        """Check if any request in the batch contains text input."""
+        for i in range(batch_size):
+            if obj[i].text:
+                return True
+            elif self.is_generation and obj[i].contains_mm_input():
+                return True
+
+        return False
+
+    def _should_use_batch_tokenization(self, batch_size, requests) -> bool:
+        """Return True if we should run the tokenizer in batch mode.
+
+        Current policy:
+        - Respect explicit server flag `enable_tokenizer_batch_encode`.
+        - Or, if no request has text or multimodal input (all use pre-tokenized input_ids or input_embeds), batch the requests without tokenization.
+        """
+        return batch_size > 0 and (
+            self.server_args.enable_tokenizer_batch_encode
+            or not self._batch_has_text(batch_size, requests)
+        )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
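Note: the batch-tokenization policy added above boils down to "batch when the server flag is set, or when no request carries text or multimodal input". A minimal standalone sketch of that predicate, using hypothetical stand-in values rather than the real ServerArgs and request objects:

    # Sketch of the policy in _should_use_batch_tokenization / _batch_has_text above.
    # `enable_flag` stands in for server_args.enable_tokenizer_batch_encode;
    # each request is reduced to a bool: does it carry text or multimodal input?
    def should_use_batch_tokenization(enable_flag, has_text_or_mm):
        batch_size = len(has_text_or_mm)
        return batch_size > 0 and (enable_flag or not any(has_text_or_mm))

    print(should_use_batch_tokenization(False, [False, False]))  # True: all pre-tokenized
    print(should_use_batch_tokenization(False, [True, False]))   # False: per-request path
    print(should_use_batch_tokenization(True, [True, True]))     # True: explicit server flag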
@@ -938,13 +986,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         generators = []
         rids = []
         if getattr(obj, "parallel_sample_num", 1) == 1:
-            if self.
-                # Validate batch tokenization constraints
-                self._validate_batch_tokenization_constraints(batch_size, obj)
-
+            if self._should_use_batch_tokenization(batch_size, obj):
                 tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
-
-                # Send as a single batched request
                 self._send_batch_request(obj, tokenized_objs, created_time)

                 # Set up generators for each request in the batch
@@ -1078,8 +1121,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
-        if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1139,11 +1180,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         return background_tasks

     def auto_create_handle_loop(self):
-        if self.
+        if self._chosen_loop is not None:
+            assert (
+                asyncio.get_event_loop() == self._chosen_loop
+            ), f"Please ensure only one event loop is ever used with SGLang. Previous loop: {self._chosen_loop}, current loop: {asyncio.get_event_loop()}"
             return

-        self.no_create_loop = True
         loop = asyncio.get_event_loop()
+        self._chosen_loop = loop
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.handle_loop))
         )
@@ -1315,12 +1359,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 )
                 continue

-            origin_rid = rid
-            if self.server_args.tokenizer_worker_num > 1:
-                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id":
+                "id": rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1389,7 +1430,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             state.finished = recv_obj.finished_reasons[i] is not None
             if state.finished:
                 if self.server_args.speculative_algorithm:
-                    meta_info
+                    self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
@@ -1537,6 +1578,43 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 ret.append(None)
         return ret

+    def _calculate_spec_decoding_metrics(
+        self,
+        meta_info: Dict[str, Any],
+        recv_obj: Union[
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
+        ],
+        i: int,
+    ) -> None:
+        """Calculate speculative decoding metrics, such as acceptance rate and acceptance length metrics."""
+        meta_info["spec_accept_rate"] = 0.0
+        meta_info["spec_accept_length"] = 0
+        meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
+        if (
+            recv_obj.spec_verify_ct[i] > 0
+            and self.server_args.speculative_num_steps is not None
+            and not isinstance(recv_obj, BatchEmbeddingOutput)
+            and hasattr(recv_obj, "spec_accepted_tokens")
+            # Checks that `spec_accepted_tokens[i]` will exist.
+            and len(recv_obj.spec_accepted_tokens) > i
+        ):
+            total_draft_tokens = (
+                recv_obj.spec_verify_ct[i] * self.server_args.speculative_num_steps
+            )
+            accepted_tokens = recv_obj.spec_accepted_tokens[i]
+
+            # Calculate per-request acceptance rate and average acceptance length.
+            if total_draft_tokens > 0:
+                # Calculate acceptance rate: accepted / (steps * lookahead)
+                meta_info["spec_accept_rate"] = accepted_tokens / total_draft_tokens
+                meta_info["spec_accept_length"] = (
+                    recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
+                )
+
     def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
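For context on the new `_calculate_spec_decoding_metrics` helper above, both metrics follow directly from per-request counters. A worked example with made-up numbers (not taken from the package):

    # Hypothetical counters for one finished request, only to illustrate the arithmetic.
    spec_verify_ct = 10          # verify steps run for this request
    speculative_num_steps = 4    # draft tokens proposed per verify step (server setting)
    spec_accepted_tokens = 25    # draft tokens accepted across all verify steps
    completion_tokens = 35       # tokens actually emitted

    total_draft_tokens = spec_verify_ct * speculative_num_steps   # 10 * 4 = 40
    spec_accept_rate = spec_accepted_tokens / total_draft_tokens  # 25 / 40 = 0.625
    spec_accept_length = completion_tokens / spec_verify_ct       # 35 / 10 = 3.5 tokens per verify step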
@@ -1637,9 +1715,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
-        origin_rid = recv_obj.rid
-        if self.server_args.tokenizer_worker_num > 1:
-            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1652,7 +1727,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             out = {
                 "text": "",
                 "meta_info": {
-                    "id":
+                    "id": recv_obj.rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -1678,6 +1753,201 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         if len(self.model_update_tmp) == self.server_args.dp_size:
             self.model_update_result.set_result(self.model_update_tmp)

+    def _initialize_multi_item_delimiter_text(self):
+        """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
+        if (
+            hasattr(self.server_args, "multi_item_scoring_delimiter")
+            and self.server_args.multi_item_scoring_delimiter is not None
+            and self.tokenizer is not None
+        ):
+            try:
+                self.multi_item_delimiter_text = self.tokenizer.decode(
+                    [self.server_args.multi_item_scoring_delimiter],
+                    skip_special_tokens=False,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
+                )
+                self.multi_item_delimiter_text = None
+
+    def _build_multi_item_token_sequence(
+        self, query: List[int], items: List[List[int]], delimiter_token_id: int
+    ) -> List[int]:
+        """
+        Build a single token sequence for multi-item scoring.
+        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: Query token IDs
+            items: List of item token ID sequences
+            delimiter_token_id: Token ID to use as delimiter
+
+        Returns:
+            Combined token sequence
+        """
+        combined_sequence = query[:]  # Start with query
+
+        for item in items:
+            combined_sequence.append(delimiter_token_id)  # Add delimiter
+            combined_sequence.extend(item)  # Add item tokens
+
+        # Add final delimiter after the last item for logprob extraction
+        combined_sequence.append(delimiter_token_id)
+
+        return combined_sequence
+
+    def _extract_logprobs_for_tokens(
+        self, logprobs_data: List, label_token_ids: List[int]
+    ) -> Dict[int, float]:
+        """
+        Extract logprobs for specified token IDs from logprobs data.
+
+        Args:
+            logprobs_data: List of (logprob, token_id, text) tuples
+            label_token_ids: Token IDs to extract logprobs for
+
+        Returns:
+            Dictionary mapping token_id to logprob
+        """
+        logprobs = {}
+        if logprobs_data:
+            for logprob, token_id, _ in logprobs_data:
+                if token_id in label_token_ids:
+                    logprobs[token_id] = logprob
+        return logprobs
+
+    def _convert_logprobs_to_scores(
+        self,
+        logprobs: Dict[int, float],
+        label_token_ids: List[int],
+        apply_softmax: bool,
+    ) -> List[float]:
+        """
+        Convert logprobs dictionary to ordered score list.
+
+        Args:
+            logprobs: Dictionary mapping token_id to logprob
+            label_token_ids: Token IDs in desired order
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of scores in the same order as label_token_ids
+        """
+        score_list = [
+            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
+        ]
+
+        if apply_softmax:
+            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
+        else:
+            # Convert logprobs to probabilities if not using softmax
+            score_list = [
+                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
+            ]
+
+        return score_list
+
+    def _process_multi_item_scoring_results(
+        self,
+        results: Any,
+        items: List,
+        label_token_ids: List[int],
+        apply_softmax: bool,
+        batch_request=None,
+    ) -> List[List[float]]:
+        """
+        Process results from multi-item scoring request.
+        Extracts logprobs at delimiter positions from input_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            items: List of items being scored
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+            batch_request: The original batch request containing input sequence
+
+        Returns:
+            List of score lists, one for each item
+        """
+        single_result = results[0] if isinstance(results, list) else results
+
+        # For multi-item scoring, logprobs are in input_token_ids_logprobs
+        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
+
+        if not input_logprobs:
+            raise RuntimeError(
+                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
+            )
+
+        scores = []
+        num_items = len(items) if isinstance(items, list) else 1
+
+        # Check if we have the expected number of logprobs
+        expected_logprobs_count = num_items + 1
+        if len(input_logprobs) != expected_logprobs_count:
+            raise RuntimeError(
+                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
+                f"with {num_items} items, but got {len(input_logprobs)}. "
+                f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
+            )
+
+        # Skip the first delimiter (between query and first item) and process remaining delimiter positions
+        # We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
+        start_idx = 1 if len(input_logprobs) > 1 else 0
+
+        # Process logprobs for each item position (excluding first delimiter)
+        for item_idx in range(num_items):
+            logprob_idx = start_idx + item_idx
+            item_logprobs_data = input_logprobs[logprob_idx]
+            logprobs = self._extract_logprobs_for_tokens(
+                item_logprobs_data, label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
+    def _process_single_item_scoring_results(
+        self, results: Any, label_token_ids: List[int], apply_softmax: bool
+    ) -> List[List[float]]:
+        """
+        Process results from single-item scoring request.
+        Single-item scoring results are stored in output_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of score lists, one for each result
+        """
+        scores = []
+
+        for result in results:
+            # For single-item scoring, logprobs are in output_token_ids_logprobs
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            if not output_logprobs or len(output_logprobs) == 0:
+                raise RuntimeError(
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
+                )
+
+            # Extract logprobs for the first (and only) position
+            logprobs = self._extract_logprobs_for_tokens(
+                output_logprobs[0], label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
     async def score_request(
         self,
         query: Optional[Union[str, List[int]]] = None,
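To make the delimiter layout produced by `_build_multi_item_token_sequence` above concrete, here is the same construction as a standalone snippet with made-up token IDs (the delimiter id 9 is arbitrary):

    # Mirrors the helper above: query, then <delimiter> + item for each item,
    # then one trailing delimiter so the last item also gets a logprob position.
    def build_multi_item_token_sequence(query, items, delimiter_token_id):
        seq = list(query)
        for item in items:
            seq.append(delimiter_token_id)
            seq.extend(item)
        seq.append(delimiter_token_id)
        return seq

    print(build_multi_item_token_sequence([1, 2], [[3], [4, 5], [6]], 9))
    # [1, 2, 9, 3, 9, 4, 5, 9, 6, 9]  -> len(items) + 1 = 4 delimiter positions,
    # matching the expected_logprobs_count check in the diff above.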
@@ -1688,7 +1958,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         request: Optional[Any] = None,
     ) -> List[List[float]]:
         """
-
+        Score the probability of specified token IDs appearing after the given (query + item) pair.
+
+        This method supports two scoring approaches:
+        1. Single-Item scoring (default): Process each query+item pair independently
+        2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
+           multiple items into a single sequence using delimiter for efficient processing.
+           Note: item_first parameter is ignored in multi-item scoring mode since it uses
+           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Multi-item scoring works with both text and pre-tokenized inputs:
+        - Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
+        - Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+
+        Args:
+            query: The query text or pre-tokenized query token IDs
+            items: The item text(s) or pre-tokenized item token IDs
+            label_token_ids: List of token IDs to compute probabilities for
+            apply_softmax: Whether to normalize probabilities using softmax
+            item_first: If True, prepend items to query. Ignored for multi-item scoring.
+            request: Optional FastAPI request object
+
+        Returns:
+            List of lists containing probabilities for each item and each label token
         """
         if label_token_ids is None:
             raise ValueError("label_token_ids must be provided")
@@ -1701,9 +1993,17 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                 )

+        # Check if multi-item scoring is enabled by presence of delimiter
+        use_multi_item_scoring = (
+            self.server_args.multi_item_scoring_delimiter is not None
+            and self.multi_item_delimiter_text is not None
+        )
+
         batch_request = GenerateReqInput(
             token_ids_logprob=label_token_ids,
             return_logprob=True,
+            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
+            logprob_start_len=0 if use_multi_item_scoring else -1,
             stream=False,
             sampling_params={"max_new_tokens": 0},
         )
@@ -1715,12 +2015,23 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         ):
             # Both query and items are text
             items_list = [items] if isinstance(items, str) else items
-            if item_first:
-                prompts = [f"{item}{query}" for item in items_list]
-            else:
-                prompts = [f"{query}{item}" for item in items_list]

-
+            if use_multi_item_scoring:
+                # Multi-item scoring: create single prompt with delimiter text
+                # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+                # (item_first is ignored for multi-item scoring)
+                delimiter = self.multi_item_delimiter_text
+                combined_items = delimiter.join(items_list)
+                # Add final delimiter after the last item for logprob extraction
+                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
+                batch_request.text = [single_prompt]
+            else:
+                # Single-item scoring: create separate prompts for each item
+                if item_first:
+                    prompts = [f"{item}{query}" for item in items_list]
+                else:
+                    prompts = [f"{query}{item}" for item in items_list]
+                batch_request.text = prompts

         elif (
             isinstance(query, list)
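The text branch above builds the same layout at the string level. A small illustration with a made-up delimiter string (the real `multi_item_delimiter_text` is whatever the configured delimiter token id decodes to):

    query = "Is the item relevant to the query?"
    items_list = ["item one", "item two", "item three"]
    delimiter = "<sep>"  # stand-in for self.multi_item_delimiter_text

    single_prompt = f"{query}{delimiter}{delimiter.join(items_list)}{delimiter}"
    print(single_prompt)
    # Is the item relevant to the query?<sep>item one<sep>item two<sep>item three<sep>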
@@ -1729,61 +2040,38 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             and isinstance(items[0], list)
         ):
             # Both query and items are token IDs
-            if
-
+            if use_multi_item_scoring:
+                # Multi-item scoring: concatenate with delimiter token ID
+                # Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
+                combined_input_ids = self._build_multi_item_token_sequence(
+                    query, items, delimiter_token_id
+                )
+                batch_request.input_ids = [combined_input_ids]
             else:
-
-
-
+                # Single-item scoring: process each item separately
+                if item_first:
+                    input_ids_list = [item + query for item in items]
+                else:
+                    input_ids_list = [query + item for item in items]
+                batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
             )

         results = await self.generate_request(batch_request, request).__anext__()
-        scores = []
-
-        for result in results:
-            # Get logprobs for each token
-            logprobs = {}
-
-            # For scoring requests, we read from output_token_ids_logprobs since we want
-            # the logprobs for specific tokens mentioned in the label_token_ids at
-            # the next position after the last token in the prompt
-            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])

-
-
-
-
-
-
-
-
-
-
-
-            for logprob, token_id, _ in output_logprobs[0]:
-                if token_id in label_token_ids:
-                    logprobs[token_id] = logprob
-
-            # Get scores in order of label_token_ids
-            score_list = [
-                logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
-            ]
-
-            # Apply softmax to logprobs if needed
-            if apply_softmax:
-                score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
-            else:
-                # Convert logprobs to probabilities if not using softmax
-                score_list = [
-                    math.exp(x) if x != float("-inf") else 0.0 for x in score_list
-                ]
-
-            scores.append(score_list)
-
-        return scores
+        if use_multi_item_scoring:
+            # Multi-item scoring: extract scores from input_token_ids_logprobs
+            return self._process_multi_item_scoring_results(
+                results, items, label_token_ids, apply_softmax, batch_request
+            )
+        else:
+            # Single-item scoring: process each result separately
+            return self._process_single_item_scoring_results(
+                results, label_token_ids, apply_softmax
+            )

     async def watch_load_thread(self):
         # Only for dp_controller when dp_size > 1