sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -31,18 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
 import torch
@@ -53,80 +42,49 @@ from fastapi import BackgroundTasks

 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.utils import (
-
-
-
-    get_kv_class,
-)
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
-from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.lora.lora_registry import LoRARegistry
+from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-
-
-
-
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
-    ClearHiCacheReqInput,
-    ClearHiCacheReqOutput,
-    CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
+    GetLoadReqInput,
     HealthCheckOutput,
-
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
-    MultiTokenizerWarpper,
-    OpenSessionReqInput,
+    MultiTokenizerWrapper,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.tracing.trace import (
+    trace_get_proc_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     configure_gc_warning,
     dataclass_to_string_truncated,
@@ -136,6 +94,11 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -180,7 +143,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""

     def __init__(
@@ -199,6 +162,7 @@ class TokenizerManager:
             else None
         )
         self.crash_dump_folder = server_args.crash_dump_folder
+        self.enable_trace = server_args.enable_trace

         # Read model args
         self.model_path = server_args.model_path
@@ -210,8 +174,19 @@ class TokenizerManager:
         self.image_token_id = self.model_config.image_token_id
         self.max_req_input_len = None  # Will be set later in engine.py

+        speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.reserve_input_token_num = (
+            0
+            if speculative_algorithm.is_none()
+            else server_args.speculative_num_draft_tokens
+        )
+        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
+        self.multi_item_delimiter_text = None
+
         if self.model_config.is_multimodal:
-            import_processors()
+            import_processors("sglang.srt.multimodal.processors")
             try:
                 _processor = get_processor(
                     server_args.tokenizer_path,
@@ -250,6 +225,7 @@ class TokenizerManager:
             self.processor = _processor
             self.tokenizer = get_tokenizer_from_processor(self.processor)
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
+            self._initialize_multi_item_delimiter_text()
         else:
             self.mm_processor = self.processor = None

@@ -262,6 +238,19 @@ class TokenizerManager:
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
+            self._initialize_multi_item_delimiter_text()
+        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
+        if (
+            server_args.enable_dynamic_batch_tokenizer
+            and not server_args.skip_tokenizer_init
+        ):
+            self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
+                self.tokenizer,
+                max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
+                batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
+            )
+        else:
+            self.async_dynamic_batch_tokenizer = None

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -319,8 +308,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

-
-
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.bootstrap_server = start_disagg_service(self.server_args)

         # For load balancing
         self.current_load = 0
@@ -328,12 +319,16 @@ class TokenizerManager:

         # Metrics
         if self.enable_metrics:
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                # TODO: Add lora name/path in the future,
+            }
+            if server_args.tokenizer_metrics_allowed_custom_labels:
+                for label in server_args.tokenizer_metrics_allowed_custom_labels:
+                    labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
                 server_args=server_args,
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
+                labels=labels,
                 bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                 bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                 bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
@@ -344,58 +339,14 @@ class TokenizerManager:
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)

-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.clear_hicache_storage_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
                     (
-
-
-
-
+                        BatchStrOutput,
+                        BatchEmbeddingOutput,
+                        BatchTokenIDOutput,
+                        BatchMultimodalOutput,
                     ),
                     self._handle_batch_output,
                 ),
@@ -405,100 +356,15 @@ class TokenizerManager:
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    ClearHiCacheReqOutput,
-                    self.clear_hicache_storage_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )

-
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_communicators(server_args)

     async def generate_request(
         self,
@@ -518,6 +384,9 @@ class TokenizerManager:
             # If it's a single value, add worker_id prefix
             obj.rid = f"{self.worker_id}_{obj.rid}"

+        if self.enable_trace:
+            self._trace_request_start(obj, created_time)
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -543,6 +412,144 @@ class TokenizerManager:
         ):
             yield response

+    def _detect_input_format(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool
+    ) -> str:
+        """Detect the format of input texts for proper tokenization handling.
+
+        Returns:
+            - "single_string": Regular single text like "Hello world"
+            - "batch_strings": Regular batch like ["Hello", "World"]
+            - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+        """
+        if isinstance(texts, str):
+            return "single_string"
+
+        if (
+            is_cross_encoder
+            and len(texts) > 0
+            and isinstance(texts[0], list)
+            and len(texts[0]) == 2
+        ):
+            return "cross_encoder_pairs"
+
+        return "batch_strings"
+
+    def _prepare_tokenizer_input(
+        self, texts: Union[str, List[str]], input_format: str
+    ) -> Union[List[str], List[List[str]]]:
+        """Prepare input for the tokenizer based on detected format."""
+        if input_format == "single_string":
+            return [texts]  # Wrap single string for batch processing
+        elif input_format == "cross_encoder_pairs":
+            return texts  # Already in correct format: [["query", "doc"]]
+        else:  # batch_strings
+            return texts  # Already in correct format: ["text1", "text2"]
+
+    def _extract_tokenizer_results(
+        self,
+        input_ids: List[List[int]],
+        token_type_ids: Optional[List[List[int]]],
+        input_format: str,
+        original_batch_size: int,
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """Extract results from tokenizer output based on input format."""
+
+        # For single inputs (string or single cross-encoder pair), extract first element
+        if (
+            input_format in ["single_string", "cross_encoder_pairs"]
+            and original_batch_size == 1
+        ):
+            single_input_ids = input_ids[0] if input_ids else []
+            single_token_type_ids = token_type_ids[0] if token_type_ids else None
+            return single_input_ids, single_token_type_ids
+
+        # For true batches, return as-is
+        return input_ids, token_type_ids
+
+    async def _tokenize_texts(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool = False
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """
+        Tokenize text(s) using the appropriate tokenizer strategy.
+
+        This method handles multiple input formats and chooses between async dynamic
+        batch tokenizer (for single texts only) and regular tokenizer.
+
+        Args:
+            texts: Text input in various formats:
+
+                Regular cases:
+                - Single string: "How are you?"
+                - Batch of strings: ["Hello", "World", "How are you?"]
+
+                Cross-encoder cases (sentence pairs for similarity/ranking):
+                - Single pair: [["query text", "document text"]]
+                - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+
+            is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
+                Enables proper handling of sentence pairs with segment IDs.
+
+        Returns:
+            Single input cases:
+                Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
+                Example: ([101, 2129, 102], [0, 0, 0]) for single text
+                Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
+
+            Batch input cases:
+                Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
+                Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
+
+        Note: token_type_ids is None unless is_cross_encoder=True.
+        """
+        if not texts or self.tokenizer is None:
+            raise ValueError("texts cannot be empty and tokenizer must be initialized")
+
+        # Step 1: Detect input format and prepare for tokenization
+        input_format = self._detect_input_format(texts, is_cross_encoder)
+        tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
+        original_batch_size = len(texts) if not isinstance(texts, str) else 1
+
+        # Step 2: Set up tokenizer arguments
+        tokenizer_kwargs = (
+            {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
+        )
+
+        # Step 3: Choose tokenization strategy
+        use_async_tokenizer = (
+            self.async_dynamic_batch_tokenizer is not None
+            and input_format == "single_string"
+        )
+
+        if use_async_tokenizer:
+            logger.debug("Using async dynamic batch tokenizer for single text")
+            result = await self.async_dynamic_batch_tokenizer.encode(
+                tokenizer_input[0], **tokenizer_kwargs
+            )
+            # Convert to batch format for consistency
+            input_ids = [result["input_ids"]]
+            token_type_ids = (
+                [result["token_type_ids"]]
+                if is_cross_encoder and result.get("token_type_ids")
+                else None
+            )
+        else:
+            logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
+            encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
+            input_ids = encoded["input_ids"]
+            token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
+
+        # Step 4: Extract results based on input format
+        return self._extract_tokenizer_results(
+            input_ids, token_type_ids, input_format, original_batch_size
+        )
+
     async def _tokenize_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -573,14 +580,10 @@ class TokenizerManager:
                 "accept text prompts. Please provide input_ids or re-initialize "
                 "the engine with skip_tokenizer_init=False."
             )
-        encoded = self.tokenizer(
-            input_text, return_token_type_ids=is_cross_encoder_request
-        )

-        input_ids =
-
-
-        token_type_ids = encoded.get("token_type_ids", [None])[0]
+        input_ids, token_type_ids = await self._tokenize_texts(
+            input_text, is_cross_encoder_request
+        )

         if self.mm_processor and obj.contains_mm_input():
             if not isinstance(obj.image_data, list):
@@ -600,6 +603,7 @@ class TokenizerManager:
             mm_inputs = None

         self._validate_one_request(obj, input_ids)
+        trace_slice_end("tokenize", obj.rid)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -612,6 +616,7 @@ class TokenizerManager:
         _max_req_len = self.context_len

         input_token_num = len(input_ids) if input_ids is not None else 0
+        input_token_num += self.reserve_input_token_num
         if input_token_num >= self.context_len:
             if self.server_args.allow_auto_truncate:
                 logger.warning(
@@ -674,7 +679,7 @@ class TokenizerManager:
         ):
             raise ValueError(
                 "The server is not configured to enable custom logit processor. "
-                "Please set `--enable-custom-
+                "Please set `--enable-custom-logit-processor` to enable this feature."
             )

     def _validate_input_ids_in_vocab(
@@ -713,7 +718,6 @@ class TokenizerManager:
             )

             tokenized_obj = TokenizedGenerateReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
@@ -723,6 +727,7 @@ class TokenizerManager:
                 obj.top_logprobs_num,
                 obj.token_ids_logprob,
                 obj.stream,
+                rid=obj.rid,
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
                 bootstrap_room=obj.bootstrap_room,
@@ -732,15 +737,18 @@ class TokenizerManager:
                 custom_logit_processor=obj.custom_logit_processor,
                 return_hidden_states=obj.return_hidden_states,
                 data_parallel_rank=obj.data_parallel_rank,
+                priority=obj.priority,
+                extra_key=obj.extra_key,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
                 token_type_ids,
                 sampling_params,
+                rid=obj.rid,
+                priority=obj.priority,
             )

         return tokenized_obj
@@ -755,19 +763,30 @@ class TokenizerManager:
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]

-        #
-
-
+        # Check if any request is a cross-encoder request
+        is_cross_encoder_request = any(
+            isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
+            for req in requests
+        )
+
+        # Batch tokenize all texts using unified method
+        input_ids_list, token_type_ids_list = await self._tokenize_texts(
+            texts, is_cross_encoder_request
+        )

         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
             self._validate_one_request(obj[i], input_ids_list[i])
+            token_type_ids = (
+                token_type_ids_list[i] if token_type_ids_list is not None else None
+            )
             tokenized_objs.append(
                 self._create_tokenized_object(
-                    req, req.text, input_ids_list[i], None, None
+                    req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
+            trace_slice_end("tokenize", req.rid)
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs

@@ -795,9 +814,12 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        trace_slice_start("dispatch", obj.rid)
+        tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
         return state

     def _send_batch_request(
@@ -1015,73 +1037,16 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass

-    async def flush_cache(self) -> FlushCacheReqOutput:
-        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
-    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
-        """Clear the hierarchical cache storage."""
-        # Delegate to the scheduler to handle HiCacheStorage clearing
-        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
-            0
-        ]
-
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
-        req = AbortReq(rid, abort_all)
+        req = AbortReq(rid=rid, abort_all=abort_all)
         self.send_to_scheduler.send_pyobj(req)
-
         if self.enable_metrics:
-
-
-
-
-        output_dir: Optional[str] = None,
-        start_step: Optional[int] = None,
-        num_steps: Optional[int] = None,
-        activities: Optional[List[str]] = None,
-        with_stack: Optional[bool] = None,
-        record_shapes: Optional[bool] = None,
-        profile_by_stage: bool = False,
-    ):
-        self.auto_create_handle_loop()
-        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
-        with_stack = False if with_stack is False or env_with_stack is False else True
-        req = ProfileReq(
-            type=ProfileReqType.START_PROFILE,
-            output_dir=output_dir,
-            start_step=start_step,
-            num_steps=num_steps,
-            activities=activities,
-            with_stack=with_stack,
-            record_shapes=record_shapes,
-            profile_by_stage=profile_by_stage,
-            profile_id=str(time.time()),
-        )
-        return await self._execute_profile(req)
-
-    async def stop_profile(self):
-        self.auto_create_handle_loop()
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        return await self._execute_profile(req)
-
-    async def _execute_profile(self, req: ProfileReq):
-        result = (await self.profile_communicator(req))[0]
-        if not result.success:
-            raise RuntimeError(result.message)
-        return result
-
-    async def start_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
-    async def stop_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
-    async def dump_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
+            # TODO: also use custom_labels from the request
+            self.metrics_collector.observe_one_aborted_request(
+                self.metrics_collector.labels
+            )

     async def pause_generation(self):
         async with self.is_pause_cond:
@@ -1118,7 +1083,7 @@ class TokenizerManager:
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
         if self.server_args.tokenizer_worker_num > 1:
-            obj =
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1143,291 +1108,6 @@ class TokenizerManager:
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests

-    async def init_weights_update_group(
-        self,
-        obj: InitWeightsUpdateGroupReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
-
-    async def update_weights_from_distributed(
-        self,
-        obj: UpdateWeightsFromDistributedReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
-
-    async def update_weights_from_tensor(
-        self,
-        obj: UpdateWeightsFromTensorReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_tensor_communicator(obj))[0]
-            return result.success, result.message
-
-    async def load_lora_adapter(
-        self,
-        obj: LoadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> LoadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start load Lora adapter. Lora name=%s, path=%s",
-                obj.lora_name,
-                obj.lora_path,
-            )
-
-            async with self.lora_update_lock:
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and self.lora_registry.num_registered_loras
-                    >= self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
-                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
-                        "Please unload some LoRA adapters before loading new ones."
-                    )
-
-                # Generate new uniquely identifiable LoRARef object.
-                new_adapter = LoRARef(
-                    lora_name=obj.lora_name,
-                    lora_path=obj.lora_path,
-                    pinned=obj.pinned,
-                )
-
-                # Trigger the actual loading operation at the backend processes.
-                obj.lora_id = new_adapter.lora_id
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                # Register the LoRA adapter only after loading is successful.
-                if result.success:
-                    await self.lora_registry.register(new_adapter)
-
-                return result
-        except ValueError as e:
-            return LoadLoRAAdapterReqOutput(
-                success=False,
-                error_message=str(e),
-            )
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
-    async def open_session(
-        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-
-        if obj.session_id is None:
-            obj.session_id = uuid.uuid4().hex
-        elif obj.session_id in self.session_futures:
-            return None
-
-        if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
-        self.send_to_scheduler.send_pyobj(obj)
-
-        self.session_futures[obj.session_id] = asyncio.Future()
-        session_id = await self.session_futures[obj.session_id]
|
1347
|
-
del self.session_futures[obj.session_id]
|
1348
|
-
return session_id
|
1349
|
-
|
1350
|
-
async def close_session(
|
1351
|
-
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
1352
|
-
):
|
1353
|
-
await self.send_to_scheduler.send_pyobj(obj)
|
1354
|
-
|
1355
|
-
async def get_internal_state(self) -> List[Dict[Any, Any]]:
|
1356
|
-
req = GetInternalStateReq()
|
1357
|
-
responses: List[GetInternalStateReqOutput] = (
|
1358
|
-
await self.get_internal_state_communicator(req)
|
1359
|
-
)
|
1360
|
-
# Many DP ranks
|
1361
|
-
return [res.internal_state for res in responses]
|
1362
|
-
|
1363
|
-
async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
|
1364
|
-
responses: List[SetInternalStateReqOutput] = (
|
1365
|
-
await self.set_internal_state_communicator(obj)
|
1366
|
-
)
|
1367
|
-
return [res.updated for res in responses]
|
1368
|
-
|
1369
|
-
async def get_load(self) -> dict:
|
1370
|
-
# TODO(lsyin): fake load report server
|
1371
|
-
if not self.current_load_lock.locked():
|
1372
|
-
async with self.current_load_lock:
|
1373
|
-
internal_state = await self.get_internal_state()
|
1374
|
-
self.current_load = internal_state[0]["load"]
|
1375
|
-
return {"load": self.current_load}
|
1376
|
-
|
1377
|
-
def get_log_request_metadata(self):
|
1378
|
-
max_length = None
|
1379
|
-
skip_names = None
|
1380
|
-
out_skip_names = None
|
1381
|
-
if self.log_requests:
|
1382
|
-
if self.log_requests_level == 0:
|
1383
|
-
max_length = 1 << 30
|
1384
|
-
skip_names = set(
|
1385
|
-
[
|
1386
|
-
"text",
|
1387
|
-
"input_ids",
|
1388
|
-
"input_embeds",
|
1389
|
-
"image_data",
|
1390
|
-
"audio_data",
|
1391
|
-
"lora_path",
|
1392
|
-
"sampling_params",
|
1393
|
-
]
|
1394
|
-
)
|
1395
|
-
out_skip_names = set(
|
1396
|
-
[
|
1397
|
-
"text",
|
1398
|
-
"output_ids",
|
1399
|
-
"embedding",
|
1400
|
-
]
|
1401
|
-
)
|
1402
|
-
elif self.log_requests_level == 1:
|
1403
|
-
max_length = 1 << 30
|
1404
|
-
skip_names = set(
|
1405
|
-
[
|
1406
|
-
"text",
|
1407
|
-
"input_ids",
|
1408
|
-
"input_embeds",
|
1409
|
-
"image_data",
|
1410
|
-
"audio_data",
|
1411
|
-
"lora_path",
|
1412
|
-
]
|
1413
|
-
)
|
1414
|
-
out_skip_names = set(
|
1415
|
-
[
|
1416
|
-
"text",
|
1417
|
-
"output_ids",
|
1418
|
-
"embedding",
|
1419
|
-
]
|
1420
|
-
)
|
1421
|
-
elif self.log_requests_level == 2:
|
1422
|
-
max_length = 2048
|
1423
|
-
elif self.log_requests_level == 3:
|
1424
|
-
max_length = 1 << 30
|
1425
|
-
else:
|
1426
|
-
raise ValueError(
|
1427
|
-
f"Invalid --log-requests-level: {self.log_requests_level=}"
|
1428
|
-
)
|
1429
|
-
return max_length, skip_names, out_skip_names
|
1430
|
-
|
1431
1111
|
def configure_logging(self, obj: ConfigureLoggingReq):
|
1432
1112
|
if obj.log_requests is not None:
|
1433
1113
|
self.log_requests = obj.log_requests
|
@@ -1492,6 +1172,9 @@ class TokenizerManager:
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )
 
     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
@@ -1583,12 +1266,12 @@ class TokenizerManager:
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+            remaining_rids = list(self.rid_to_state.keys())
 
             if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
-                    "Signal SIGTERM received while health check failed.
-                    remain_num_req,
+                    "Signal SIGTERM received while health check failed. Force exiting."
                 )
                 self.dump_requests_before_crash()
                 break
@@ -1596,13 +1279,12 @@ class TokenizerManager:
             elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                 # if force shutdown flag set, exit immediately
                 logger.error(
-                    "Signal SIGTERM received while force shutdown flag set. Force exiting
-                    remain_num_req,
+                    "Signal SIGTERM received while force shutdown flag set. Force exiting."
                 )
                 break
 
             logger.info(
-                f"Gracefully exiting...
+                f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)
@@ -1623,7 +1305,10 @@ class TokenizerManager:
     def _handle_batch_output(
         self,
         recv_obj: Union[
-
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
         ],
     ):
         for i, rid in enumerate(recv_obj.rids):
@@ -1657,7 +1342,7 @@ class TokenizerManager:
                     i,
                 )
 
-            if not isinstance(recv_obj,
+            if not isinstance(recv_obj, BatchEmbeddingOutput):
                 meta_info.update(
                     {
                         "completion_tokens": recv_obj.completion_tokens[i],
@@ -1668,7 +1353,7 @@ class TokenizerManager:
             if getattr(recv_obj, "output_hidden_states", None):
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
 
-            if isinstance(recv_obj,
+            if isinstance(recv_obj, BatchStrOutput):
                 state.text += recv_obj.output_strs[i]
                 if state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
@@ -1683,7 +1368,7 @@ class TokenizerManager:
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj,
+            elif isinstance(recv_obj, BatchTokenIDOutput):
                 if self.server_args.stream_output and state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
                     output_token_ids = state.output_ids[state.last_output_offset :]
@@ -1696,10 +1381,10 @@ class TokenizerManager:
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj,
+            elif isinstance(recv_obj, BatchMultimodalOutput):
                 raise NotImplementedError("BatchMultimodalOut not implemented")
             else:
-                assert isinstance(recv_obj,
+                assert isinstance(recv_obj, BatchEmbeddingOutput)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
@@ -1711,6 +1396,9 @@ class TokenizerManager:
                     meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+                trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
                 del self.rid_to_state[rid]
 
                 # Mark ongoing LoRA request as finished.
@@ -1735,7 +1423,7 @@ class TokenizerManager:
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
-        recv_obj:
+        recv_obj: BatchStrOutput,
         recv_obj_index: int,
     ):
         if recv_obj.input_token_logprobs_val is None:
@@ -1853,13 +1541,19 @@ class TokenizerManager:
                 ret.append(None)
         return ret
 
-    def collect_metrics(self, state: ReqState, recv_obj:
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
             if getattr(recv_obj, "completion_tokens", None)
            else 0
         )
 
+        custom_labels = getattr(state.obj, "custom_labels", None)
+        labels = (
+            {**self.metrics_collector.labels, **custom_labels}
+            if custom_labels
+            else self.metrics_collector.labels
+        )
         if (
             state.first_token_time == 0.0
             and self.disaggregation_mode != DisaggregationMode.PREFILL
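
Note: the `labels` merge introduced in `collect_metrics` simply overlays per-request `custom_labels` on the collector's static labels before every observation. A minimal standalone sketch of that rule (the label keys and values below are made up for illustration, not sglang defaults):

```python
# Per-request custom labels extend/override the collector's static labels.
base_labels = {"model_name": "demo-model", "engine": "sglang"}  # hypothetical static labels
custom_labels = {"tenant": "team-a"}  # hypothetical per-request labels

labels = {**base_labels, **custom_labels} if custom_labels else base_labels
assert labels == {"model_name": "demo-model", "engine": "sglang", "tenant": "team-a"}
```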
@@ -1867,7 +1561,7 @@ class TokenizerManager:
             state.first_token_time = state.last_time = time.time()
             state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
-                state.first_token_time - state.created_time
+                labels, state.first_token_time - state.created_time
             )
         else:
             num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1875,6 +1569,7 @@ class TokenizerManager:
                 new_time = time.time()
                 interval = new_time - state.last_time
                 self.metrics_collector.observe_inter_token_latency(
+                    labels,
                     interval,
                     num_new_tokens,
                 )
@@ -1889,6 +1584,7 @@ class TokenizerManager:
                 or state.obj.sampling_params.get("structural_tag", None)
             )
             self.metrics_collector.observe_one_finished_request(
+                labels,
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
@@ -1941,7 +1637,7 @@ class TokenizerManager:
 
         asyncio.create_task(asyncio.to_thread(background_task))
 
-    def _handle_abort_req(self, recv_obj):
+    def _handle_abort_req(self, recv_obj: AbortReq):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
@@ -1986,6 +1682,201 @@ class TokenizerManager:
             if len(self.model_update_tmp) == self.server_args.dp_size:
                 self.model_update_result.set_result(self.model_update_tmp)
 
+    def _initialize_multi_item_delimiter_text(self):
+        """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
+        if (
+            hasattr(self.server_args, "multi_item_scoring_delimiter")
+            and self.server_args.multi_item_scoring_delimiter is not None
+            and self.tokenizer is not None
+        ):
+            try:
+                self.multi_item_delimiter_text = self.tokenizer.decode(
+                    [self.server_args.multi_item_scoring_delimiter],
+                    skip_special_tokens=False,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
+                )
+                self.multi_item_delimiter_text = None
+
+    def _build_multi_item_token_sequence(
+        self, query: List[int], items: List[List[int]], delimiter_token_id: int
+    ) -> List[int]:
+        """
+        Build a single token sequence for multi-item scoring.
+        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: Query token IDs
+            items: List of item token ID sequences
+            delimiter_token_id: Token ID to use as delimiter
+
+        Returns:
+            Combined token sequence
+        """
+        combined_sequence = query[:]  # Start with query
+
+        for item in items:
+            combined_sequence.append(delimiter_token_id)  # Add delimiter
+            combined_sequence.extend(item)  # Add item tokens
+
+        # Add final delimiter after the last item for logprob extraction
+        combined_sequence.append(delimiter_token_id)
+
+        return combined_sequence
+
+    def _extract_logprobs_for_tokens(
+        self, logprobs_data: List, label_token_ids: List[int]
+    ) -> Dict[int, float]:
+        """
+        Extract logprobs for specified token IDs from logprobs data.
+
+        Args:
+            logprobs_data: List of (logprob, token_id, text) tuples
+            label_token_ids: Token IDs to extract logprobs for
+
+        Returns:
+            Dictionary mapping token_id to logprob
+        """
+        logprobs = {}
+        if logprobs_data:
+            for logprob, token_id, _ in logprobs_data:
+                if token_id in label_token_ids:
+                    logprobs[token_id] = logprob
+        return logprobs
+
+    def _convert_logprobs_to_scores(
+        self,
+        logprobs: Dict[int, float],
+        label_token_ids: List[int],
+        apply_softmax: bool,
+    ) -> List[float]:
+        """
+        Convert logprobs dictionary to ordered score list.
+
+        Args:
+            logprobs: Dictionary mapping token_id to logprob
+            label_token_ids: Token IDs in desired order
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of scores in the same order as label_token_ids
+        """
+        score_list = [
+            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
+        ]
+
+        if apply_softmax:
+            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
+        else:
+            # Convert logprobs to probabilities if not using softmax
+            score_list = [
+                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
+            ]
+
+        return score_list
+
+    def _process_multi_item_scoring_results(
+        self,
+        results: Any,
+        items: List,
+        label_token_ids: List[int],
+        apply_softmax: bool,
+        batch_request=None,
+    ) -> List[List[float]]:
+        """
+        Process results from multi-item scoring request.
+        Extracts logprobs at delimiter positions from input_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            items: List of items being scored
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+            batch_request: The original batch request containing input sequence
+
+        Returns:
+            List of score lists, one for each item
+        """
+        single_result = results[0] if isinstance(results, list) else results
+
+        # For multi-item scoring, logprobs are in input_token_ids_logprobs
+        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
+
+        if not input_logprobs:
+            raise RuntimeError(
+                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
+            )
+
+        scores = []
+        num_items = len(items) if isinstance(items, list) else 1
+
+        # Check if we have the expected number of logprobs
+        expected_logprobs_count = num_items + 1
+        if len(input_logprobs) != expected_logprobs_count:
+            raise RuntimeError(
+                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
+                f"with {num_items} items, but got {len(input_logprobs)}. "
+                f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
+            )
+
+        # Skip the first delimiter (between query and first item) and process remaining delimiter positions
+        # We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
+        start_idx = 1 if len(input_logprobs) > 1 else 0
+
+        # Process logprobs for each item position (excluding first delimiter)
+        for item_idx in range(num_items):
+            logprob_idx = start_idx + item_idx
+            item_logprobs_data = input_logprobs[logprob_idx]
+            logprobs = self._extract_logprobs_for_tokens(
+                item_logprobs_data, label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
+    def _process_single_item_scoring_results(
+        self, results: Any, label_token_ids: List[int], apply_softmax: bool
+    ) -> List[List[float]]:
+        """
+        Process results from single-item scoring request.
+        Single-item scoring results are stored in output_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of score lists, one for each result
+        """
+        scores = []
+
+        for result in results:
+            # For single-item scoring, logprobs are in output_token_ids_logprobs
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            if not output_logprobs or len(output_logprobs) == 0:
+                raise RuntimeError(
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
+                )
+
+            # Extract logprobs for the first (and only) position
+            logprobs = self._extract_logprobs_for_tokens(
+                output_logprobs[0], label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
     async def score_request(
         self,
         query: Optional[Union[str, List[int]]] = None,
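
For orientation, a self-contained sketch of the logprob-to-score conversion that `_convert_logprobs_to_scores` above performs (the token IDs and logprob values are made up; this is an illustration under those assumptions, not the library API):

```python
import math

import torch


def convert_logprobs_to_scores(logprobs, label_token_ids, apply_softmax):
    # Missing label tokens fall back to -inf, matching the helper above.
    score_list = [logprobs.get(t, float("-inf")) for t in label_token_ids]
    if apply_softmax:
        # Normalize across the label set.
        return torch.softmax(torch.tensor(score_list), dim=0).tolist()
    # Otherwise exponentiate each logprob independently.
    return [math.exp(x) if x != float("-inf") else 0.0 for x in score_list]


# Hypothetical "yes"/"no" label tokens (IDs 101/102) with made-up logprobs.
print(convert_logprobs_to_scores({101: -0.2, 102: -1.9}, [101, 102], apply_softmax=True))
```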
@@ -1996,7 +1887,29 @@ class TokenizerManager:
         request: Optional[Any] = None,
     ) -> List[List[float]]:
         """
-
+        Score the probability of specified token IDs appearing after the given (query + item) pair.
+
+        This method supports two scoring approaches:
+        1. Single-Item scoring (default): Process each query+item pair independently
+        2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
+           multiple items into a single sequence using delimiter for efficient processing.
+           Note: item_first parameter is ignored in multi-item scoring mode since it uses
+           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Multi-item scoring works with both text and pre-tokenized inputs:
+        - Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
+        - Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+
+        Args:
+            query: The query text or pre-tokenized query token IDs
+            items: The item text(s) or pre-tokenized item token IDs
+            label_token_ids: List of token IDs to compute probabilities for
+            apply_softmax: Whether to normalize probabilities using softmax
+            item_first: If True, prepend items to query. Ignored for multi-item scoring.
+            request: Optional FastAPI request object
+
+        Returns:
+            List of lists containing probabilities for each item and each label token
         """
         if label_token_ids is None:
             raise ValueError("label_token_ids must be provided")
@@ -2009,9 +1922,17 @@ class TokenizerManager:
                     f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                 )
 
+        # Check if multi-item scoring is enabled by presence of delimiter
+        use_multi_item_scoring = (
+            self.server_args.multi_item_scoring_delimiter is not None
+            and self.multi_item_delimiter_text is not None
+        )
+
         batch_request = GenerateReqInput(
             token_ids_logprob=label_token_ids,
             return_logprob=True,
+            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
+            logprob_start_len=0 if use_multi_item_scoring else -1,
             stream=False,
             sampling_params={"max_new_tokens": 0},
         )
@@ -2023,12 +1944,23 @@ class TokenizerManager:
         ):
             # Both query and items are text
             items_list = [items] if isinstance(items, str) else items
-            if item_first:
-                prompts = [f"{item}{query}" for item in items_list]
-            else:
-                prompts = [f"{query}{item}" for item in items_list]
 
-
+            if use_multi_item_scoring:
+                # Multi-item scoring: create single prompt with delimiter text
+                # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+                # (item_first is ignored for multi-item scoring)
+                delimiter = self.multi_item_delimiter_text
+                combined_items = delimiter.join(items_list)
+                # Add final delimiter after the last item for logprob extraction
+                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
+                batch_request.text = [single_prompt]
+            else:
+                # Single-item scoring: create separate prompts for each item
+                if item_first:
+                    prompts = [f"{item}{query}" for item in items_list]
+                else:
+                    prompts = [f"{query}{item}" for item in items_list]
+                batch_request.text = prompts
 
         elif (
             isinstance(query, list)
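
A small sketch of the prompt layout the multi-item text branch above produces (the query, items, and delimiter strings below are made up for illustration):

```python
query = "Is the review positive?"  # made-up query
items = ["Great phone", "Broke in a week"]  # made-up items
delimiter = "<|sep|>"  # made-up delimiter text

# query<delimiter>item1<delimiter>item2<delimiter>, with a trailing delimiter
# so a logprob can be read at every item boundary.
single_prompt = f"{query}{delimiter}" + delimiter.join(items) + delimiter
assert single_prompt == "Is the review positive?<|sep|>Great phone<|sep|>Broke in a week<|sep|>"
```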
@@ -2037,57 +1969,75 @@ class TokenizerManager:
             and isinstance(items[0], list)
         ):
             # Both query and items are token IDs
-            if
-
+            if use_multi_item_scoring:
+                # Multi-item scoring: concatenate with delimiter token ID
+                # Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
+                combined_input_ids = self._build_multi_item_token_sequence(
+                    query, items, delimiter_token_id
+                )
+                batch_request.input_ids = [combined_input_ids]
             else:
-
-
-
+                # Single-item scoring: process each item separately
+                if item_first:
+                    input_ids_list = [item + query for item in items]
+                else:
+                    input_ids_list = [query + item for item in items]
+                batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
             )
 
         results = await self.generate_request(batch_request, request).__anext__()
-        scores = []
-
-        for result in results:
-            # Get logprobs for each token
-            logprobs = {}
-
-            # For scoring requests, we read from output_token_ids_logprobs since we want
-            # the logprobs for specific tokens mentioned in the label_token_ids at
-            # the next position after the last token in the prompt
-            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
-
-            # Throw an error here if output_logprobs is None
-            if output_logprobs is None:
-                raise RuntimeError(
-                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
-                    "This usually indicates a problem with the scoring request or the backend output."
-                )
-
-            for logprob, token_id, _ in output_logprobs[0]:
-                if token_id in label_token_ids:
-                    logprobs[token_id] = logprob
 
-
-
-
-
+        if use_multi_item_scoring:
+            # Multi-item scoring: extract scores from input_token_ids_logprobs
+            return self._process_multi_item_scoring_results(
+                results, items, label_token_ids, apply_softmax, batch_request
+            )
+        else:
+            # Single-item scoring: process each result separately
+            return self._process_single_item_scoring_results(
+                results, label_token_ids, apply_softmax
+            )
 
-
-
-
-
-
-
-
-            ]
+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
 
-
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_udpate_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_udpate_req)
 
-
+    def _trace_request_start(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        created_time: Optional[float] = None,
+    ):
+        if obj.is_single:
+            bootstrap_room = (
+                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+            )
+            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+        else:
+            for i in range(len(obj.rid)):
+                bootstrap_room = (
+                    obj.bootstrap_room[i]
+                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                    else None
+                )
+                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_slice_start(
+                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                )
 
 
 class ServerStatus(Enum):
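
The token-ID branch above delegates to `_build_multi_item_token_sequence`; a toy standalone run of the same layout (all token IDs below are made up):

```python
def build_multi_item_token_sequence(query, items, delimiter_token_id):
    # query<delim>item1<delim>item2<delim>..., with a trailing delimiter
    # so each item boundary has a position to read logprobs from.
    seq = query[:]
    for item in items:
        seq.append(delimiter_token_id)
        seq.extend(item)
    seq.append(delimiter_token_id)
    return seq


# Made-up token IDs: query [1, 2], two items, delimiter 99.
assert build_multi_item_token_sequence([1, 2], [[3], [4, 5]], 99) == [1, 2, 99, 3, 99, 4, 5, 99]
```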
@@ -2134,57 +2084,12 @@ class SignalHandler:
 
     def running_phase_sigquit_handler(self, signum=None, frame=None):
         logger.error(
-            "
+            f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
         )
         self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())
 
 
-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    enable_multi_tokenizer = False
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #