sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
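The largest change below is the rewrite of `sglang/srt/managers/tokenizer_manager.py`: the per-feature `_Communicator` objects and their `TypeBasedDispatcher` registrations are deleted from `TokenizerManager.__init__` and replaced by a single `self.init_communicators(server_args)` call inherited from the new `TokenizerCommunicatorMixin` (see `sglang/srt/managers/tokenizer_communicator_mixin.py` in the list above). For orientation, here is a minimal sketch of the fan-out/fan-in pattern those deleted objects implement, reconstructed from the hunks; `Communicator`, `TypeDispatcher`, and `demo` are hypothetical simplifications, not sglang's actual API.

```python
import asyncio
from typing import Any, Callable, Dict, List, Type


class Communicator:
    """Send one request to the scheduler and gather `fan_in` replies.

    Sketch of the role played by sglang's `_Communicator`, which pairs a
    request with the `dp_size` responses that come back over ZMQ.
    """

    def __init__(self, send: Callable[[Any], None], fan_in: int):
        self._send = send
        self._fan_in = fan_in
        self._replies: List[Any] = []
        self._done = asyncio.Event()

    async def __call__(self, req: Any) -> List[Any]:
        self._replies.clear()
        self._done.clear()
        self._send(req)          # stands in for send_to_scheduler.send_pyobj(req)
        await self._done.wait()  # released once every rank has answered
        return list(self._replies)

    def handle_recv(self, reply: Any) -> None:
        self._replies.append(reply)
        if len(self._replies) == self._fan_in:
            self._done.set()


class TypeDispatcher:
    """Route each incoming reply to the handler registered for its type,
    in the spirit of the `TypeBasedDispatcher` table in the hunks below."""

    def __init__(self, handlers: Dict[Type, Callable[[Any], None]]):
        self._handlers = handlers

    def __call__(self, obj: Any) -> None:
        self._handlers[type(obj)](obj)


async def demo() -> None:
    # Simulate dp_size = 2: the "scheduler" produces one reply per rank.
    comm = Communicator(send=lambda req: None, fan_in=2)
    dispatcher = TypeDispatcher({str: comm.handle_recv})

    async def fake_scheduler() -> None:
        for rank in range(2):
            dispatcher(f"reply from rank {rank}")

    replies, _ = await asyncio.gather(comm("flush_cache"), fake_scheduler())
    print(replies)  # ['reply from rank 0', 'reply from rank 1']


if __name__ == "__main__":
    asyncio.run(demo())
```

Judging from the hunks, the mixin centralizes this boilerplate rather than changing the pattern: replies are still routed by type to a `handle_recv`, which completes the pending call once one response per data-parallel rank has arrived.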
@@ -31,18 +31,7 @@ from contextlib import nullcontext
|
|
31
31
|
from datetime import datetime
|
32
32
|
from enum import Enum
|
33
33
|
from http import HTTPStatus
|
34
|
-
from typing import
|
35
|
-
Any,
|
36
|
-
Awaitable,
|
37
|
-
Deque,
|
38
|
-
Dict,
|
39
|
-
Generic,
|
40
|
-
List,
|
41
|
-
Optional,
|
42
|
-
Tuple,
|
43
|
-
TypeVar,
|
44
|
-
Union,
|
45
|
-
)
|
34
|
+
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
|
46
35
|
|
47
36
|
import fastapi
|
48
37
|
import torch
|
@@ -53,80 +42,49 @@ from fastapi import BackgroundTasks
|
|
53
42
|
|
54
43
|
from sglang.srt.aio_rwlock import RWLock
|
55
44
|
from sglang.srt.configs.model_config import ModelConfig
|
56
|
-
from sglang.srt.disaggregation.utils import
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
get_kv_class,
|
61
|
-
)
|
62
|
-
from sglang.srt.hf_transformers_utils import (
|
63
|
-
get_processor,
|
64
|
-
get_tokenizer,
|
65
|
-
get_tokenizer_from_processor,
|
66
|
-
)
|
67
|
-
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
45
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
46
|
+
from sglang.srt.lora.lora_registry import LoRARegistry
|
47
|
+
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
|
48
|
+
from sglang.srt.managers.disagg_service import start_disagg_service
|
68
49
|
from sglang.srt.managers.io_struct import (
|
69
50
|
AbortReq,
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
51
|
+
BatchEmbeddingOutput,
|
52
|
+
BatchMultimodalOutput,
|
53
|
+
BatchStrOutput,
|
54
|
+
BatchTokenIDOutput,
|
74
55
|
BatchTokenizedEmbeddingReqInput,
|
75
56
|
BatchTokenizedGenerateReqInput,
|
76
|
-
ClearHiCacheReqInput,
|
77
|
-
ClearHiCacheReqOutput,
|
78
|
-
CloseSessionReqInput,
|
79
57
|
ConfigureLoggingReq,
|
80
58
|
EmbeddingReqInput,
|
81
|
-
ExpertDistributionReq,
|
82
|
-
ExpertDistributionReqOutput,
|
83
|
-
FlushCacheReqInput,
|
84
|
-
FlushCacheReqOutput,
|
85
59
|
FreezeGCReq,
|
86
60
|
GenerateReqInput,
|
87
|
-
|
88
|
-
GetInternalStateReqOutput,
|
89
|
-
GetWeightsByNameReqInput,
|
90
|
-
GetWeightsByNameReqOutput,
|
61
|
+
GetLoadReqInput,
|
91
62
|
HealthCheckOutput,
|
92
|
-
|
93
|
-
InitWeightsUpdateGroupReqOutput,
|
94
|
-
LoadLoRAAdapterReqInput,
|
95
|
-
LoadLoRAAdapterReqOutput,
|
96
|
-
LoRAUpdateResult,
|
97
|
-
MultiTokenizerWarpper,
|
98
|
-
OpenSessionReqInput,
|
63
|
+
MultiTokenizerWrapper,
|
99
64
|
OpenSessionReqOutput,
|
100
|
-
ProfileReq,
|
101
|
-
ProfileReqOutput,
|
102
|
-
ProfileReqType,
|
103
|
-
ReleaseMemoryOccupationReqInput,
|
104
|
-
ReleaseMemoryOccupationReqOutput,
|
105
|
-
ResumeMemoryOccupationReqInput,
|
106
|
-
ResumeMemoryOccupationReqOutput,
|
107
65
|
SessionParams,
|
108
|
-
SetInternalStateReq,
|
109
|
-
SetInternalStateReqOutput,
|
110
|
-
SlowDownReqInput,
|
111
|
-
SlowDownReqOutput,
|
112
66
|
TokenizedEmbeddingReqInput,
|
113
67
|
TokenizedGenerateReqInput,
|
114
|
-
UnloadLoRAAdapterReqInput,
|
115
|
-
UnloadLoRAAdapterReqOutput,
|
116
68
|
UpdateWeightFromDiskReqInput,
|
117
69
|
UpdateWeightFromDiskReqOutput,
|
118
|
-
|
119
|
-
UpdateWeightsFromDistributedReqOutput,
|
120
|
-
UpdateWeightsFromTensorReqInput,
|
121
|
-
UpdateWeightsFromTensorReqOutput,
|
70
|
+
WatchLoadUpdateReq,
|
122
71
|
)
|
123
72
|
from sglang.srt.managers.mm_utils import TensorTransportMode
|
124
73
|
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
125
74
|
from sglang.srt.managers.scheduler import is_health_check_generate_req
|
126
75
|
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
|
76
|
+
from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
|
127
77
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
128
78
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
129
79
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
80
|
+
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
81
|
+
from sglang.srt.tracing.trace import (
|
82
|
+
trace_get_proc_propagate_context,
|
83
|
+
trace_req_finish,
|
84
|
+
trace_req_start,
|
85
|
+
trace_slice_end,
|
86
|
+
trace_slice_start,
|
87
|
+
)
|
130
88
|
from sglang.srt.utils import (
|
131
89
|
configure_gc_warning,
|
132
90
|
dataclass_to_string_truncated,
|
@@ -136,6 +94,11 @@ from sglang.srt.utils import (
|
|
136
94
|
get_zmq_socket,
|
137
95
|
kill_process_tree,
|
138
96
|
)
|
97
|
+
from sglang.srt.utils.hf_transformers_utils import (
|
98
|
+
get_processor,
|
99
|
+
get_tokenizer,
|
100
|
+
get_tokenizer_from_processor,
|
101
|
+
)
|
139
102
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
140
103
|
|
141
104
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
@@ -180,7 +143,7 @@ class ReqState:
|
|
180
143
|
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
181
144
|
|
182
145
|
|
183
|
-
class TokenizerManager:
|
146
|
+
class TokenizerManager(TokenizerCommunicatorMixin):
|
184
147
|
"""TokenizerManager is a process that tokenizes the text."""
|
185
148
|
|
186
149
|
def __init__(
|
@@ -199,6 +162,7 @@ class TokenizerManager:
|
|
199
162
|
else None
|
200
163
|
)
|
201
164
|
self.crash_dump_folder = server_args.crash_dump_folder
|
165
|
+
self.enable_trace = server_args.enable_trace
|
202
166
|
|
203
167
|
# Read model args
|
204
168
|
self.model_path = server_args.model_path
|
@@ -210,8 +174,17 @@ class TokenizerManager:
|
|
210
174
|
self.image_token_id = self.model_config.image_token_id
|
211
175
|
self.max_req_input_len = None # Will be set later in engine.py
|
212
176
|
|
177
|
+
speculative_algorithm = SpeculativeAlgorithm.from_string(
|
178
|
+
server_args.speculative_algorithm
|
179
|
+
)
|
180
|
+
self.reserve_input_token_num = (
|
181
|
+
0
|
182
|
+
if speculative_algorithm.is_none()
|
183
|
+
else server_args.speculative_num_draft_tokens
|
184
|
+
)
|
185
|
+
|
213
186
|
if self.model_config.is_multimodal:
|
214
|
-
import_processors()
|
187
|
+
import_processors("sglang.srt.multimodal.processors")
|
215
188
|
try:
|
216
189
|
_processor = get_processor(
|
217
190
|
server_args.tokenizer_path,
|
@@ -262,6 +235,18 @@ class TokenizerManager:
|
|
262
235
|
trust_remote_code=server_args.trust_remote_code,
|
263
236
|
revision=server_args.revision,
|
264
237
|
)
|
238
|
+
# Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
|
239
|
+
if (
|
240
|
+
server_args.enable_dynamic_batch_tokenizer
|
241
|
+
and not server_args.skip_tokenizer_init
|
242
|
+
):
|
243
|
+
self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
|
244
|
+
self.tokenizer,
|
245
|
+
max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
|
246
|
+
batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
|
247
|
+
)
|
248
|
+
else:
|
249
|
+
self.async_dynamic_batch_tokenizer = None
|
265
250
|
|
266
251
|
# Init inter-process communication
|
267
252
|
context = zmq.asyncio.Context(2)
|
@@ -319,8 +304,10 @@ class TokenizerManager:
|
|
319
304
|
# LoRA updates and inference to overlap.
|
320
305
|
self.lora_update_lock = asyncio.Lock()
|
321
306
|
|
322
|
-
|
323
|
-
|
307
|
+
self.disaggregation_mode = DisaggregationMode(
|
308
|
+
self.server_args.disaggregation_mode
|
309
|
+
)
|
310
|
+
self.bootstrap_server = start_disagg_service(self.server_args)
|
324
311
|
|
325
312
|
# For load balancing
|
326
313
|
self.current_load = 0
|
@@ -328,11 +315,16 @@ class TokenizerManager:
|
|
328
315
|
|
329
316
|
# Metrics
|
330
317
|
if self.enable_metrics:
|
318
|
+
labels = {
|
319
|
+
"model_name": self.server_args.served_model_name,
|
320
|
+
# TODO: Add lora name/path in the future,
|
321
|
+
}
|
322
|
+
if server_args.tokenizer_metrics_allowed_custom_labels:
|
323
|
+
for label in server_args.tokenizer_metrics_allowed_custom_labels:
|
324
|
+
labels[label] = ""
|
331
325
|
self.metrics_collector = TokenizerMetricsCollector(
|
332
|
-
|
333
|
-
|
334
|
-
# TODO: Add lora name/path in the future,
|
335
|
-
},
|
326
|
+
server_args=server_args,
|
327
|
+
labels=labels,
|
336
328
|
bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
|
337
329
|
bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
|
338
330
|
bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
|
@@ -343,58 +335,14 @@ class TokenizerManager:
|
|
343
335
|
if self.server_args.gc_warning_threshold_secs > 0.0:
|
344
336
|
configure_gc_warning(self.server_args.gc_warning_threshold_secs)
|
345
337
|
|
346
|
-
# Communicators
|
347
|
-
self.init_weights_update_group_communicator = _Communicator(
|
348
|
-
self.send_to_scheduler, server_args.dp_size
|
349
|
-
)
|
350
|
-
self.update_weights_from_distributed_communicator = _Communicator(
|
351
|
-
self.send_to_scheduler, server_args.dp_size
|
352
|
-
)
|
353
|
-
self.update_weights_from_tensor_communicator = _Communicator(
|
354
|
-
self.send_to_scheduler, server_args.dp_size
|
355
|
-
)
|
356
|
-
self.get_weights_by_name_communicator = _Communicator(
|
357
|
-
self.send_to_scheduler, server_args.dp_size
|
358
|
-
)
|
359
|
-
self.release_memory_occupation_communicator = _Communicator(
|
360
|
-
self.send_to_scheduler, server_args.dp_size
|
361
|
-
)
|
362
|
-
self.resume_memory_occupation_communicator = _Communicator(
|
363
|
-
self.send_to_scheduler, server_args.dp_size
|
364
|
-
)
|
365
|
-
self.slow_down_communicator = _Communicator(
|
366
|
-
self.send_to_scheduler, server_args.dp_size
|
367
|
-
)
|
368
|
-
self.flush_cache_communicator = _Communicator(
|
369
|
-
self.send_to_scheduler, server_args.dp_size
|
370
|
-
)
|
371
|
-
self.clear_hicache_storage_communicator = _Communicator(
|
372
|
-
self.send_to_scheduler, server_args.dp_size
|
373
|
-
)
|
374
|
-
self.profile_communicator = _Communicator(
|
375
|
-
self.send_to_scheduler, server_args.dp_size
|
376
|
-
)
|
377
|
-
self.get_internal_state_communicator = _Communicator(
|
378
|
-
self.send_to_scheduler, server_args.dp_size
|
379
|
-
)
|
380
|
-
self.set_internal_state_communicator = _Communicator(
|
381
|
-
self.send_to_scheduler, server_args.dp_size
|
382
|
-
)
|
383
|
-
self.expert_distribution_communicator = _Communicator(
|
384
|
-
self.send_to_scheduler, server_args.dp_size
|
385
|
-
)
|
386
|
-
self.update_lora_adapter_communicator = _Communicator(
|
387
|
-
self.send_to_scheduler, server_args.dp_size
|
388
|
-
)
|
389
|
-
|
390
338
|
self._result_dispatcher = TypeBasedDispatcher(
|
391
339
|
[
|
392
340
|
(
|
393
341
|
(
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
342
|
+
BatchStrOutput,
|
343
|
+
BatchEmbeddingOutput,
|
344
|
+
BatchTokenIDOutput,
|
345
|
+
BatchMultimodalOutput,
|
398
346
|
),
|
399
347
|
self._handle_batch_output,
|
400
348
|
),
|
@@ -404,100 +352,15 @@ class TokenizerManager:
|
|
404
352
|
UpdateWeightFromDiskReqOutput,
|
405
353
|
self._handle_update_weights_from_disk_req_output,
|
406
354
|
),
|
407
|
-
(
|
408
|
-
InitWeightsUpdateGroupReqOutput,
|
409
|
-
self.init_weights_update_group_communicator.handle_recv,
|
410
|
-
),
|
411
|
-
(
|
412
|
-
UpdateWeightsFromDistributedReqOutput,
|
413
|
-
self.update_weights_from_distributed_communicator.handle_recv,
|
414
|
-
),
|
415
|
-
(
|
416
|
-
UpdateWeightsFromTensorReqOutput,
|
417
|
-
self.update_weights_from_tensor_communicator.handle_recv,
|
418
|
-
),
|
419
|
-
(
|
420
|
-
GetWeightsByNameReqOutput,
|
421
|
-
self.get_weights_by_name_communicator.handle_recv,
|
422
|
-
),
|
423
|
-
(
|
424
|
-
ReleaseMemoryOccupationReqOutput,
|
425
|
-
self.release_memory_occupation_communicator.handle_recv,
|
426
|
-
),
|
427
|
-
(
|
428
|
-
ResumeMemoryOccupationReqOutput,
|
429
|
-
self.resume_memory_occupation_communicator.handle_recv,
|
430
|
-
),
|
431
|
-
(
|
432
|
-
SlowDownReqOutput,
|
433
|
-
self.slow_down_communicator.handle_recv,
|
434
|
-
),
|
435
|
-
(
|
436
|
-
ClearHiCacheReqOutput,
|
437
|
-
self.clear_hicache_storage_communicator.handle_recv,
|
438
|
-
),
|
439
|
-
(
|
440
|
-
FlushCacheReqOutput,
|
441
|
-
self.flush_cache_communicator.handle_recv,
|
442
|
-
),
|
443
|
-
(
|
444
|
-
ProfileReqOutput,
|
445
|
-
self.profile_communicator.handle_recv,
|
446
|
-
),
|
447
355
|
(
|
448
356
|
FreezeGCReq,
|
449
357
|
lambda x: None,
|
450
358
|
), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
|
451
|
-
(
|
452
|
-
GetInternalStateReqOutput,
|
453
|
-
self.get_internal_state_communicator.handle_recv,
|
454
|
-
),
|
455
|
-
(
|
456
|
-
SetInternalStateReqOutput,
|
457
|
-
self.set_internal_state_communicator.handle_recv,
|
458
|
-
),
|
459
|
-
(
|
460
|
-
ExpertDistributionReqOutput,
|
461
|
-
self.expert_distribution_communicator.handle_recv,
|
462
|
-
),
|
463
|
-
(
|
464
|
-
LoRAUpdateResult,
|
465
|
-
self.update_lora_adapter_communicator.handle_recv,
|
466
|
-
),
|
467
359
|
(HealthCheckOutput, lambda x: None),
|
468
360
|
]
|
469
361
|
)
|
470
362
|
|
471
|
-
|
472
|
-
self.disaggregation_mode = DisaggregationMode(
|
473
|
-
self.server_args.disaggregation_mode
|
474
|
-
)
|
475
|
-
self.disaggregation_transfer_backend = TransferBackend(
|
476
|
-
self.server_args.disaggregation_transfer_backend
|
477
|
-
)
|
478
|
-
# Start kv boostrap server on prefill
|
479
|
-
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
480
|
-
# only start bootstrap server on prefill tm
|
481
|
-
kv_bootstrap_server_class = get_kv_class(
|
482
|
-
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
483
|
-
)
|
484
|
-
self.bootstrap_server = kv_bootstrap_server_class(
|
485
|
-
self.server_args.disaggregation_bootstrap_port
|
486
|
-
)
|
487
|
-
is_create_store = (
|
488
|
-
self.server_args.node_rank == 0
|
489
|
-
and self.server_args.disaggregation_transfer_backend == "ascend"
|
490
|
-
)
|
491
|
-
if is_create_store:
|
492
|
-
try:
|
493
|
-
from mf_adapter import create_config_store
|
494
|
-
|
495
|
-
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
496
|
-
create_config_store(ascend_url)
|
497
|
-
except Exception as e:
|
498
|
-
error_message = f"Failed create mf store, invalid ascend_url."
|
499
|
-
error_message += f" With exception {e}"
|
500
|
-
raise error_message
|
363
|
+
self.init_communicators(server_args)
|
501
364
|
|
502
365
|
async def generate_request(
|
503
366
|
self,
|
@@ -517,6 +380,9 @@ class TokenizerManager:
|
|
517
380
|
# If it's a single value, add worker_id prefix
|
518
381
|
obj.rid = f"{self.worker_id}_{obj.rid}"
|
519
382
|
|
383
|
+
if self.enable_trace:
|
384
|
+
self._trace_request_start(obj, created_time)
|
385
|
+
|
520
386
|
if self.log_requests:
|
521
387
|
max_length, skip_names, _ = self.log_request_metadata
|
522
388
|
logger.info(
|
@@ -542,6 +408,144 @@ class TokenizerManager:
|
|
542
408
|
):
|
543
409
|
yield response
|
544
410
|
|
411
|
+
def _detect_input_format(
|
412
|
+
self, texts: Union[str, List[str]], is_cross_encoder: bool
|
413
|
+
) -> str:
|
414
|
+
"""Detect the format of input texts for proper tokenization handling.
|
415
|
+
|
416
|
+
Returns:
|
417
|
+
- "single_string": Regular single text like "Hello world"
|
418
|
+
- "batch_strings": Regular batch like ["Hello", "World"]
|
419
|
+
- "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
|
420
|
+
"""
|
421
|
+
if isinstance(texts, str):
|
422
|
+
return "single_string"
|
423
|
+
|
424
|
+
if (
|
425
|
+
is_cross_encoder
|
426
|
+
and len(texts) > 0
|
427
|
+
and isinstance(texts[0], list)
|
428
|
+
and len(texts[0]) == 2
|
429
|
+
):
|
430
|
+
return "cross_encoder_pairs"
|
431
|
+
|
432
|
+
return "batch_strings"
|
433
|
+
|
434
|
+
def _prepare_tokenizer_input(
|
435
|
+
self, texts: Union[str, List[str]], input_format: str
|
436
|
+
) -> Union[List[str], List[List[str]]]:
|
437
|
+
"""Prepare input for the tokenizer based on detected format."""
|
438
|
+
if input_format == "single_string":
|
439
|
+
return [texts] # Wrap single string for batch processing
|
440
|
+
elif input_format == "cross_encoder_pairs":
|
441
|
+
return texts # Already in correct format: [["query", "doc"]]
|
442
|
+
else: # batch_strings
|
443
|
+
return texts # Already in correct format: ["text1", "text2"]
|
444
|
+
|
445
|
+
def _extract_tokenizer_results(
|
446
|
+
self,
|
447
|
+
input_ids: List[List[int]],
|
448
|
+
token_type_ids: Optional[List[List[int]]],
|
449
|
+
input_format: str,
|
450
|
+
original_batch_size: int,
|
451
|
+
) -> Union[
|
452
|
+
Tuple[List[int], Optional[List[int]]],
|
453
|
+
Tuple[List[List[int]], Optional[List[List[int]]]],
|
454
|
+
]:
|
455
|
+
"""Extract results from tokenizer output based on input format."""
|
456
|
+
|
457
|
+
# For single inputs (string or single cross-encoder pair), extract first element
|
458
|
+
if (
|
459
|
+
input_format in ["single_string", "cross_encoder_pairs"]
|
460
|
+
and original_batch_size == 1
|
461
|
+
):
|
462
|
+
single_input_ids = input_ids[0] if input_ids else []
|
463
|
+
single_token_type_ids = token_type_ids[0] if token_type_ids else None
|
464
|
+
return single_input_ids, single_token_type_ids
|
465
|
+
|
466
|
+
# For true batches, return as-is
|
467
|
+
return input_ids, token_type_ids
|
468
|
+
|
469
|
+
async def _tokenize_texts(
|
470
|
+
self, texts: Union[str, List[str]], is_cross_encoder: bool = False
|
471
|
+
) -> Union[
|
472
|
+
Tuple[List[int], Optional[List[int]]],
|
473
|
+
Tuple[List[List[int]], Optional[List[List[int]]]],
|
474
|
+
]:
|
475
|
+
"""
|
476
|
+
Tokenize text(s) using the appropriate tokenizer strategy.
|
477
|
+
|
478
|
+
This method handles multiple input formats and chooses between async dynamic
|
479
|
+
batch tokenizer (for single texts only) and regular tokenizer.
|
480
|
+
|
481
|
+
Args:
|
482
|
+
texts: Text input in various formats:
|
483
|
+
|
484
|
+
Regular cases:
|
485
|
+
- Single string: "How are you?"
|
486
|
+
- Batch of strings: ["Hello", "World", "How are you?"]
|
487
|
+
|
488
|
+
Cross-encoder cases (sentence pairs for similarity/ranking):
|
489
|
+
- Single pair: [["query text", "document text"]]
|
490
|
+
- Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
|
491
|
+
|
492
|
+
is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
|
493
|
+
Enables proper handling of sentence pairs with segment IDs.
|
494
|
+
|
495
|
+
Returns:
|
496
|
+
Single input cases:
|
497
|
+
Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
|
498
|
+
Example: ([101, 2129, 102], [0, 0, 0]) for single text
|
499
|
+
Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
|
500
|
+
|
501
|
+
Batch input cases:
|
502
|
+
Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
|
503
|
+
Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
|
504
|
+
|
505
|
+
Note: token_type_ids is None unless is_cross_encoder=True.
|
506
|
+
"""
|
507
|
+
if not texts or self.tokenizer is None:
|
508
|
+
raise ValueError("texts cannot be empty and tokenizer must be initialized")
|
509
|
+
|
510
|
+
# Step 1: Detect input format and prepare for tokenization
|
511
|
+
input_format = self._detect_input_format(texts, is_cross_encoder)
|
512
|
+
tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
|
513
|
+
original_batch_size = len(texts) if not isinstance(texts, str) else 1
|
514
|
+
|
515
|
+
# Step 2: Set up tokenizer arguments
|
516
|
+
tokenizer_kwargs = (
|
517
|
+
{"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
|
518
|
+
)
|
519
|
+
|
520
|
+
# Step 3: Choose tokenization strategy
|
521
|
+
use_async_tokenizer = (
|
522
|
+
self.async_dynamic_batch_tokenizer is not None
|
523
|
+
and input_format == "single_string"
|
524
|
+
)
|
525
|
+
|
526
|
+
if use_async_tokenizer:
|
527
|
+
logger.debug("Using async dynamic batch tokenizer for single text")
|
528
|
+
result = await self.async_dynamic_batch_tokenizer.encode(
|
529
|
+
tokenizer_input[0], **tokenizer_kwargs
|
530
|
+
)
|
531
|
+
# Convert to batch format for consistency
|
532
|
+
input_ids = [result["input_ids"]]
|
533
|
+
token_type_ids = (
|
534
|
+
[result["token_type_ids"]]
|
535
|
+
if is_cross_encoder and result.get("token_type_ids")
|
536
|
+
else None
|
537
|
+
)
|
538
|
+
else:
|
539
|
+
logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
|
540
|
+
encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
|
541
|
+
input_ids = encoded["input_ids"]
|
542
|
+
token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
|
543
|
+
|
544
|
+
# Step 4: Extract results based on input format
|
545
|
+
return self._extract_tokenizer_results(
|
546
|
+
input_ids, token_type_ids, input_format, original_batch_size
|
547
|
+
)
|
548
|
+
|
545
549
|
async def _tokenize_one_request(
|
546
550
|
self,
|
547
551
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
@@ -572,14 +576,10 @@ class TokenizerManager:
|
|
572
576
|
"accept text prompts. Please provide input_ids or re-initialize "
|
573
577
|
"the engine with skip_tokenizer_init=False."
|
574
578
|
)
|
575
|
-
encoded = self.tokenizer(
|
576
|
-
input_text, return_token_type_ids=is_cross_encoder_request
|
577
|
-
)
|
578
579
|
|
579
|
-
input_ids =
|
580
|
-
|
581
|
-
|
582
|
-
token_type_ids = encoded.get("token_type_ids", [None])[0]
|
580
|
+
input_ids, token_type_ids = await self._tokenize_texts(
|
581
|
+
input_text, is_cross_encoder_request
|
582
|
+
)
|
583
583
|
|
584
584
|
if self.mm_processor and obj.contains_mm_input():
|
585
585
|
if not isinstance(obj.image_data, list):
|
@@ -599,6 +599,7 @@ class TokenizerManager:
|
|
599
599
|
mm_inputs = None
|
600
600
|
|
601
601
|
self._validate_one_request(obj, input_ids)
|
602
|
+
trace_slice_end("tokenize", obj.rid)
|
602
603
|
return self._create_tokenized_object(
|
603
604
|
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
|
604
605
|
)
|
@@ -611,6 +612,7 @@ class TokenizerManager:
|
|
611
612
|
_max_req_len = self.context_len
|
612
613
|
|
613
614
|
input_token_num = len(input_ids) if input_ids is not None else 0
|
615
|
+
input_token_num += self.reserve_input_token_num
|
614
616
|
if input_token_num >= self.context_len:
|
615
617
|
if self.server_args.allow_auto_truncate:
|
616
618
|
logger.warning(
|
@@ -673,7 +675,7 @@ class TokenizerManager:
|
|
673
675
|
):
|
674
676
|
raise ValueError(
|
675
677
|
"The server is not configured to enable custom logit processor. "
|
676
|
-
"Please set `--enable-custom-
|
678
|
+
"Please set `--enable-custom-logit-processor` to enable this feature."
|
677
679
|
)
|
678
680
|
|
679
681
|
def _validate_input_ids_in_vocab(
|
@@ -712,7 +714,6 @@ class TokenizerManager:
|
|
712
714
|
)
|
713
715
|
|
714
716
|
tokenized_obj = TokenizedGenerateReqInput(
|
715
|
-
obj.rid,
|
716
717
|
input_text,
|
717
718
|
input_ids,
|
718
719
|
mm_inputs,
|
@@ -722,6 +723,7 @@ class TokenizerManager:
|
|
722
723
|
obj.top_logprobs_num,
|
723
724
|
obj.token_ids_logprob,
|
724
725
|
obj.stream,
|
726
|
+
rid=obj.rid,
|
725
727
|
bootstrap_host=obj.bootstrap_host,
|
726
728
|
bootstrap_port=obj.bootstrap_port,
|
727
729
|
bootstrap_room=obj.bootstrap_room,
|
@@ -731,15 +733,18 @@ class TokenizerManager:
|
|
731
733
|
custom_logit_processor=obj.custom_logit_processor,
|
732
734
|
return_hidden_states=obj.return_hidden_states,
|
733
735
|
data_parallel_rank=obj.data_parallel_rank,
|
736
|
+
priority=obj.priority,
|
737
|
+
extra_key=obj.extra_key,
|
734
738
|
)
|
735
739
|
elif isinstance(obj, EmbeddingReqInput):
|
736
740
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
737
|
-
obj.rid,
|
738
741
|
input_text,
|
739
742
|
input_ids,
|
740
743
|
mm_inputs,
|
741
744
|
token_type_ids,
|
742
745
|
sampling_params,
|
746
|
+
rid=obj.rid,
|
747
|
+
priority=obj.priority,
|
743
748
|
)
|
744
749
|
|
745
750
|
return tokenized_obj
|
@@ -754,19 +759,30 @@ class TokenizerManager:
|
|
754
759
|
requests = [obj[i] for i in range(batch_size)]
|
755
760
|
texts = [req.text for req in requests]
|
756
761
|
|
757
|
-
#
|
758
|
-
|
759
|
-
|
762
|
+
# Check if any request is a cross-encoder request
|
763
|
+
is_cross_encoder_request = any(
|
764
|
+
isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
|
765
|
+
for req in requests
|
766
|
+
)
|
767
|
+
|
768
|
+
# Batch tokenize all texts using unified method
|
769
|
+
input_ids_list, token_type_ids_list = await self._tokenize_texts(
|
770
|
+
texts, is_cross_encoder_request
|
771
|
+
)
|
760
772
|
|
761
773
|
# Process all requests
|
762
774
|
tokenized_objs = []
|
763
775
|
for i, req in enumerate(requests):
|
764
776
|
self._validate_one_request(obj[i], input_ids_list[i])
|
777
|
+
token_type_ids = (
|
778
|
+
token_type_ids_list[i] if token_type_ids_list is not None else None
|
779
|
+
)
|
765
780
|
tokenized_objs.append(
|
766
781
|
self._create_tokenized_object(
|
767
|
-
req, req.text, input_ids_list[i], None, None
|
782
|
+
req, req.text, input_ids_list[i], None, None, token_type_ids
|
768
783
|
)
|
769
784
|
)
|
785
|
+
trace_slice_end("tokenize", req.rid)
|
770
786
|
logger.debug(f"Completed batch processing for {batch_size} requests")
|
771
787
|
return tokenized_objs
|
772
788
|
|
@@ -794,9 +810,12 @@ class TokenizerManager:
|
|
794
810
|
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
795
811
|
created_time: Optional[float] = None,
|
796
812
|
):
|
813
|
+
trace_slice_start("dispatch", obj.rid)
|
814
|
+
tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
|
797
815
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
798
816
|
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
799
817
|
self.rid_to_state[obj.rid] = state
|
818
|
+
trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
|
800
819
|
return state
|
801
820
|
|
802
821
|
def _send_batch_request(
|
@@ -1014,73 +1033,16 @@ class TokenizerManager:
|
|
1014
1033
|
except StopAsyncIteration:
|
1015
1034
|
pass
|
1016
1035
|
|
1017
|
-
async def flush_cache(self) -> FlushCacheReqOutput:
|
1018
|
-
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
|
1019
|
-
|
1020
|
-
async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
|
1021
|
-
"""Clear the hierarchical cache storage."""
|
1022
|
-
# Delegate to the scheduler to handle HiCacheStorage clearing
|
1023
|
-
return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
|
1024
|
-
0
|
1025
|
-
]
|
1026
|
-
|
1027
1036
|
def abort_request(self, rid: str = "", abort_all: bool = False):
|
1028
1037
|
if not abort_all and rid not in self.rid_to_state:
|
1029
1038
|
return
|
1030
|
-
req = AbortReq(rid, abort_all)
|
1039
|
+
req = AbortReq(rid=rid, abort_all=abort_all)
|
1031
1040
|
self.send_to_scheduler.send_pyobj(req)
|
1032
|
-
|
1033
1041
|
if self.enable_metrics:
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
output_dir: Optional[str] = None,
|
1039
|
-
start_step: Optional[int] = None,
|
1040
|
-
num_steps: Optional[int] = None,
|
1041
|
-
activities: Optional[List[str]] = None,
|
1042
|
-
with_stack: Optional[bool] = None,
|
1043
|
-
record_shapes: Optional[bool] = None,
|
1044
|
-
profile_by_stage: bool = False,
|
1045
|
-
):
|
1046
|
-
self.auto_create_handle_loop()
|
1047
|
-
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
|
1048
|
-
with_stack = False if with_stack is False or env_with_stack is False else True
|
1049
|
-
req = ProfileReq(
|
1050
|
-
type=ProfileReqType.START_PROFILE,
|
1051
|
-
output_dir=output_dir,
|
1052
|
-
start_step=start_step,
|
1053
|
-
num_steps=num_steps,
|
1054
|
-
activities=activities,
|
1055
|
-
with_stack=with_stack,
|
1056
|
-
record_shapes=record_shapes,
|
1057
|
-
profile_by_stage=profile_by_stage,
|
1058
|
-
profile_id=str(time.time()),
|
1059
|
-
)
|
1060
|
-
return await self._execute_profile(req)
|
1061
|
-
|
1062
|
-
async def stop_profile(self):
|
1063
|
-
self.auto_create_handle_loop()
|
1064
|
-
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
1065
|
-
return await self._execute_profile(req)
|
1066
|
-
|
1067
|
-
async def _execute_profile(self, req: ProfileReq):
|
1068
|
-
result = (await self.profile_communicator(req))[0]
|
1069
|
-
if not result.success:
|
1070
|
-
raise RuntimeError(result.message)
|
1071
|
-
return result
|
1072
|
-
|
1073
|
-
async def start_expert_distribution_record(self):
|
1074
|
-
self.auto_create_handle_loop()
|
1075
|
-
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
|
1076
|
-
|
1077
|
-
async def stop_expert_distribution_record(self):
|
1078
|
-
self.auto_create_handle_loop()
|
1079
|
-
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
|
1080
|
-
|
1081
|
-
async def dump_expert_distribution_record(self):
|
1082
|
-
self.auto_create_handle_loop()
|
1083
|
-
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
1042
|
+
# TODO: also use custom_labels from the request
|
1043
|
+
self.metrics_collector.observe_one_aborted_request(
|
1044
|
+
self.metrics_collector.labels
|
1045
|
+
)
|
1084
1046
|
|
1085
1047
|
async def pause_generation(self):
|
1086
1048
|
async with self.is_pause_cond:
|
@@ -1117,7 +1079,7 @@ class TokenizerManager:
|
|
1117
1079
|
self, obj: UpdateWeightFromDiskReqInput
|
1118
1080
|
) -> Tuple[bool, str]:
|
1119
1081
|
if self.server_args.tokenizer_worker_num > 1:
|
1120
|
-
obj =
|
1082
|
+
obj = MultiTokenizerWrapper(self.worker_id, obj)
|
1121
1083
|
self.send_to_scheduler.send_pyobj(obj)
|
1122
1084
|
self.model_update_result = asyncio.Future()
|
1123
1085
|
if self.server_args.dp_size == 1:
|
@@ -1142,291 +1104,6 @@ class TokenizerManager:
|
|
1142
1104
|
all_paused_requests = [r.num_paused_requests for r in result]
|
1143
1105
|
return all_success, all_message, all_paused_requests
|
1144
1106
|
|
1145
|
-
async def init_weights_update_group(
|
1146
|
-
self,
|
1147
|
-
obj: InitWeightsUpdateGroupReqInput,
|
1148
|
-
request: Optional[fastapi.Request] = None,
|
1149
|
-
) -> Tuple[bool, str]:
|
1150
|
-
self.auto_create_handle_loop()
|
1151
|
-
assert (
|
1152
|
-
self.server_args.dp_size == 1
|
1153
|
-
), "dp_size must be 1 for init parameter update group"
|
1154
|
-
result = (await self.init_weights_update_group_communicator(obj))[0]
|
1155
|
-
return result.success, result.message
|
1156
|
-
|
1157
|
-
async def update_weights_from_distributed(
|
1158
|
-
self,
|
1159
|
-
obj: UpdateWeightsFromDistributedReqInput,
|
1160
|
-
request: Optional[fastapi.Request] = None,
|
1161
|
-
) -> Tuple[bool, str]:
|
1162
|
-
self.auto_create_handle_loop()
|
1163
|
-
assert (
|
1164
|
-
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
1165
|
-
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
|
1166
|
-
|
1167
|
-
if obj.abort_all_requests:
|
1168
|
-
self.abort_request(abort_all=True)
|
1169
|
-
|
1170
|
-
# This means that weight sync
|
1171
|
-
# cannot run while requests are in progress.
|
1172
|
-
async with self.model_update_lock.writer_lock:
|
1173
|
-
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
1174
|
-
return result.success, result.message
|
1175
|
-
|
1176
|
-
async def update_weights_from_tensor(
|
1177
|
-
self,
|
1178
|
-
obj: UpdateWeightsFromTensorReqInput,
|
1179
|
-
request: Optional[fastapi.Request] = None,
|
1180
|
-
) -> Tuple[bool, str]:
|
1181
|
-
self.auto_create_handle_loop()
|
1182
|
-
assert (
|
1183
|
-
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
1184
|
-
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
|
1185
|
-
|
1186
|
-
if obj.abort_all_requests:
|
1187
|
-
self.abort_request(abort_all=True)
|
1188
|
-
|
1189
|
-
# This means that weight sync
|
1190
|
-
# cannot run while requests are in progress.
|
1191
|
-
async with self.model_update_lock.writer_lock:
|
1192
|
-
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
1193
|
-
return result.success, result.message
|
1194
|
-
|
1195
|
-
async def load_lora_adapter(
|
1196
|
-
self,
|
1197
|
-
obj: LoadLoRAAdapterReqInput,
|
1198
|
-
_: Optional[fastapi.Request] = None,
|
1199
|
-
) -> LoadLoRAAdapterReqOutput:
|
1200
|
-
self.auto_create_handle_loop()
|
1201
|
-
|
1202
|
-
try:
|
1203
|
-
if not self.server_args.enable_lora:
|
1204
|
-
raise ValueError(
|
1205
|
-
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
1206
|
-
)
|
1207
|
-
|
1208
|
-
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
|
1209
|
-
# with dp_size > 1.
|
1210
|
-
assert (
|
1211
|
-
self.server_args.dp_size == 1
|
1212
|
-
), "dp_size must be 1 for dynamic lora loading"
|
1213
|
-
logger.info(
|
1214
|
-
"Start load Lora adapter. Lora name=%s, path=%s",
|
1215
|
-
obj.lora_name,
|
1216
|
-
obj.lora_path,
|
1217
|
-
)
|
1218
|
-
|
1219
|
-
async with self.lora_update_lock:
|
1220
|
-
if (
|
1221
|
-
self.server_args.max_loaded_loras is not None
|
1222
|
-
and self.lora_registry.num_registered_loras
|
1223
|
-
>= self.server_args.max_loaded_loras
|
1224
|
-
):
|
1225
|
-
raise ValueError(
|
1226
|
-
f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
|
1227
|
-
f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
|
1228
|
-
"Please unload some LoRA adapters before loading new ones."
|
1229
|
-
)
|
1230
|
-
|
1231
|
-
# Generate new uniquely identifiable LoRARef object.
|
1232
|
-
new_adapter = LoRARef(
|
1233
|
-
lora_name=obj.lora_name,
|
1234
|
-
lora_path=obj.lora_path,
|
1235
|
-
pinned=obj.pinned,
|
1236
|
-
)
|
1237
|
-
|
1238
|
-
# Trigger the actual loading operation at the backend processes.
|
1239
|
-
obj.lora_id = new_adapter.lora_id
|
1240
|
-
result = (await self.update_lora_adapter_communicator(obj))[0]
|
1241
|
-
|
1242
|
-
# Register the LoRA adapter only after loading is successful.
|
1243
|
-
if result.success:
|
1244
|
-
await self.lora_registry.register(new_adapter)
|
1245
|
-
|
1246
|
-
return result
|
1247
|
-
except ValueError as e:
|
1248
|
-
return LoadLoRAAdapterReqOutput(
|
1249
|
-
success=False,
|
1250
|
-
error_message=str(e),
|
1251
|
-
)
|
1252
|
-
|
1253
|
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
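Unloading reverses that order: unregister first so no new request can select the adapter, wait for in-flight users to drain, then free it at the backends. A sketch with an assumed per-adapter usage counter (the real registry's `wait_for_unload` is more involved):

import asyncio


class AdapterUsage:
    """Per-adapter in-flight counter that can be awaited down to zero."""

    def __init__(self):
        self.count = 0
        self.idle = asyncio.Event()
        self.idle.set()

    def acquire(self):
        self.count += 1
        self.idle.clear()

    def release(self):
        self.count -= 1
        if self.count == 0:
            self.idle.set()


async def unload_adapter(name: str, registry: dict, usage: dict, backend_unload):
    registry.pop(name)             # 1) stop routing new requests to the adapter
    await usage[name].idle.wait()  # 2) let in-flight requests drain
    await backend_unload(name)     # 3) now it is safe to free backend memory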
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
-    async def open_session(
-        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-
-        if obj.session_id is None:
-            obj.session_id = uuid.uuid4().hex
-        elif obj.session_id in self.session_futures:
-            return None
-
-        if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
-        self.send_to_scheduler.send_pyobj(obj)
-
-        self.session_futures[obj.session_id] = asyncio.Future()
-        session_id = await self.session_futures[obj.session_id]
-        del self.session_futures[obj.session_id]
-        return session_id
-
-    async def close_session(
-        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
-    ):
-        await self.send_to_scheduler.send_pyobj(obj)
-
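`open_session` resolves asynchronously: the request goes to the scheduler over ZMQ and the caller parks on an `asyncio.Future` keyed by session id, which the reply handler later completes. The same future-per-request idiom in miniature (names are illustrative, not sglang's API):

import asyncio
import uuid


class SessionClient:
    def __init__(self, send):
        self._send = send  # e.g. a ZMQ socket's send_pyobj
        self._futures: dict[str, asyncio.Future] = {}

    async def open_session(self, session_id: str | None = None):
        session_id = session_id or uuid.uuid4().hex
        if session_id in self._futures:
            return None  # an open for this id is already pending
        self._send({"type": "open_session", "session_id": session_id})
        fut = asyncio.get_running_loop().create_future()
        self._futures[session_id] = fut
        try:
            return await fut  # completed by on_reply() below
        finally:
            del self._futures[session_id]

    def on_reply(self, session_id: str, result):
        # Called from the receive loop when the scheduler answers.
        self._futures[session_id].set_result(result)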
-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.updated for res in responses]
-
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
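`get_load` collapses concurrent probes: if another caller already holds `current_load_lock`, the method skips the refresh and serves the cached value, so simultaneous load queries cost one scheduler round-trip. The idiom in isolation (a sketch, not the sglang class):

import asyncio


class LoadReporter:
    def __init__(self, fetch):
        self._fetch = fetch  # coroutine that asks the scheduler for its load
        self._lock = asyncio.Lock()
        self.current_load = 0

    async def get_load(self) -> dict:
        if not self._lock.locked():  # refresh already running? serve the cache
            async with self._lock:
                self.current_load = await self._fetch()
        return {"load": self.current_load}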
-    def get_log_request_metadata(self):
-        max_length = None
-        skip_names = None
-        out_skip_names = None
-        if self.log_requests:
-            if self.log_requests_level == 0:
-                max_length = 1 << 30
-                skip_names = set(
-                    [
-                        "text",
-                        "input_ids",
-                        "input_embeds",
-                        "image_data",
-                        "audio_data",
-                        "lora_path",
-                        "sampling_params",
-                    ]
-                )
-                out_skip_names = set(
-                    [
-                        "text",
-                        "output_ids",
-                        "embedding",
-                    ]
-                )
-            elif self.log_requests_level == 1:
-                max_length = 1 << 30
-                skip_names = set(
-                    [
-                        "text",
-                        "input_ids",
-                        "input_embeds",
-                        "image_data",
-                        "audio_data",
-                        "lora_path",
-                    ]
-                )
-                out_skip_names = set(
-                    [
-                        "text",
-                        "output_ids",
-                        "embedding",
-                    ]
-                )
-            elif self.log_requests_level == 2:
-                max_length = 2048
-            elif self.log_requests_level == 3:
-                max_length = 1 << 30
-            else:
-                raise ValueError(
-                    f"Invalid --log-requests-level: {self.log_requests_level=}"
-                )
-        return max_length, skip_names, out_skip_names
 
     def configure_logging(self, obj: ConfigureLoggingReq):
         if obj.log_requests is not None:
             self.log_requests = obj.log_requests
@@ -1491,6 +1168,9 @@ class TokenizerManager:
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )
 
     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
@@ -1582,12 +1262,12 @@ class TokenizerManager:
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+            remaining_rids = list(self.rid_to_state.keys())
 
             if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
-                    "Signal SIGTERM received while health check failed.
-                    remain_num_req,
+                    "Signal SIGTERM received while health check failed. Force exiting."
                 )
                 self.dump_requests_before_crash()
                 break
@@ -1595,13 +1275,12 @@ class TokenizerManager:
             elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                 # if force shutdown flag set, exit immediately
                 logger.error(
-                    "Signal SIGTERM received while force shutdown flag set. Force exiting
-                    remain_num_req,
+                    "Signal SIGTERM received while force shutdown flag set. Force exiting."
                 )
                 break
 
             logger.info(
-                f"Gracefully exiting...
+                f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)
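The reworked shutdown path above polls `rid_to_state` until it empties, now logging the remaining request ids each round. Stripped of the sglang specifics, the drain loop is roughly the following sketch (`force` and `unhealthy` stand in for the `SGL_FORCE_SHUTDOWN` flag and a failed health check):

import asyncio
import logging

logger = logging.getLogger(__name__)


async def drain(rid_to_state: dict, force: bool = False, unhealthy: bool = False):
    while True:
        remain_num_req = len(rid_to_state)
        remaining_rids = list(rid_to_state.keys())
        if unhealthy or force:
            logger.error("Signal SIGTERM received. Force exiting.")
            break
        logger.info(
            "Gracefully exiting... Remaining number of requests %d. Remaining requests %s.",
            remain_num_req,
            remaining_rids,
        )
        if remain_num_req == 0:
            break
        await asyncio.sleep(5)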
@@ -1622,7 +1301,10 @@ class TokenizerManager:
     def _handle_batch_output(
         self,
         recv_obj: Union[
-            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
         ],
     ):
         for i, rid in enumerate(recv_obj.rids):
@@ -1656,7 +1338,7 @@ class TokenizerManager:
                     i,
                 )
 
-            if not isinstance(recv_obj, BatchEmbeddingOut):
+            if not isinstance(recv_obj, BatchEmbeddingOutput):
                 meta_info.update(
                     {
                         "completion_tokens": recv_obj.completion_tokens[i],
@@ -1667,7 +1349,7 @@ class TokenizerManager:
            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
 
-            if isinstance(recv_obj, BatchStrOut):
+            if isinstance(recv_obj, BatchStrOutput):
                 state.text += recv_obj.output_strs[i]
                 if state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
@@ -1682,7 +1364,7 @@ class TokenizerManager:
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchTokenIDOut):
+            elif isinstance(recv_obj, BatchTokenIDOutput):
                 if self.server_args.stream_output and state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
                     output_token_ids = state.output_ids[state.last_output_offset :]
@@ -1695,10 +1377,10 @@ class TokenizerManager:
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchMultimodalOut):
+            elif isinstance(recv_obj, BatchMultimodalOutput):
                 raise NotImplementedError("BatchMultimodalOut not implemented")
             else:
-                assert isinstance(recv_obj, BatchEmbeddingOut)
+                assert isinstance(recv_obj, BatchEmbeddingOutput)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
@@ -1710,6 +1392,9 @@ class TokenizerManager:
                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
             state.finished_time = time.time()
             meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+            trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
             del self.rid_to_state[rid]
 
             # Mark ongoing LoRA request as finished.
@@ -1734,7 +1419,7 @@ class TokenizerManager:
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
-        recv_obj: BatchStrOut,
+        recv_obj: BatchStrOutput,
         recv_obj_index: int,
     ):
         if recv_obj.input_token_logprobs_val is None:
@@ -1852,13 +1537,19 @@ class TokenizerManager:
                 ret.append(None)
         return ret
 
-    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
             if getattr(recv_obj, "completion_tokens", None)
            else 0
        )
 
+        custom_labels = getattr(state.obj, "custom_labels", None)
+        labels = (
+            {**self.metrics_collector.labels, **custom_labels}
+            if custom_labels
+            else self.metrics_collector.labels
+        )
         if (
             state.first_token_time == 0.0
             and self.disaggregation_mode != DisaggregationMode.PREFILL
@@ -1866,7 +1557,7 @@ class TokenizerManager:
             state.first_token_time = state.last_time = time.time()
             state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
-                state.first_token_time - state.created_time
+                labels, state.first_token_time - state.created_time
             )
         else:
             num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1874,6 +1565,7 @@ class TokenizerManager:
                 new_time = time.time()
                 interval = new_time - state.last_time
                 self.metrics_collector.observe_inter_token_latency(
+                    labels,
                     interval,
                     num_new_tokens,
                 )
@@ -1888,6 +1580,7 @@ class TokenizerManager:
                 or state.obj.sampling_params.get("structural_tag", None)
             )
             self.metrics_collector.observe_one_finished_request(
+                labels,
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
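The metrics change threads a `labels` dict through every observation so per-request `custom_labels` can override the collector's base labels; the merge is a plain dict spread. A minimal sketch against `prometheus_client`, assuming custom labels only override label names already declared on the metric (the label schema here is invented for illustration):

from prometheus_client import Histogram

# Base labels every observation carries; a request's custom_labels may override them.
BASE_LABELS = {"model_name": "example-model", "engine": "sglang"}

TTFT = Histogram(
    "time_to_first_token_seconds",
    "Time to first token",
    labelnames=list(BASE_LABELS),
)


def observe_ttft(value: float, custom_labels: dict | None = None):
    labels = {**BASE_LABELS, **custom_labels} if custom_labels else BASE_LABELS
    TTFT.labels(**labels).observe(value)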
@@ -1940,7 +1633,7 @@ class TokenizerManager:
 
         asyncio.create_task(asyncio.to_thread(background_task))
 
-    def _handle_abort_req(self, recv_obj):
+    def _handle_abort_req(self, recv_obj: AbortReq):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
@@ -2059,11 +1752,15 @@ class TokenizerManager:
         # the next position after the last token in the prompt
         output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
 
-        #
-        if
+        # Check if output_logprobs is properly populated
+        if (
+            output_logprobs is None
+            or not output_logprobs
+            or len(output_logprobs) == 0
+        ):
             raise RuntimeError(
-                f"output_logprobs is
-                "This
+                f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for the scoring request."
             )
 
         for logprob, token_id, _ in output_logprobs[0]:
|
@@ -2088,6 +1785,43 @@ class TokenizerManager:
|
|
2088
1785
|
|
2089
1786
|
return scores
|
2090
1787
|
|
1788
|
+
async def watch_load_thread(self):
|
1789
|
+
# Only for dp_controller when dp_size > 1
|
1790
|
+
if (
|
1791
|
+
self.server_args.dp_size == 1
|
1792
|
+
or self.server_args.load_balance_method == "round_robin"
|
1793
|
+
):
|
1794
|
+
return
|
1795
|
+
|
1796
|
+
while True:
|
1797
|
+
await asyncio.sleep(self.server_args.load_watch_interval)
|
1798
|
+
loads = await self.get_load_communicator(GetLoadReqInput())
|
1799
|
+
load_udpate_req = WatchLoadUpdateReq(loads=loads)
|
1800
|
+
self.send_to_scheduler.send_pyobj(load_udpate_req)
|
1801
|
+
|
1802
|
+
def _trace_request_start(
|
1803
|
+
self,
|
1804
|
+
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
1805
|
+
created_time: Optional[float] = None,
|
1806
|
+
):
|
1807
|
+
if obj.is_single:
|
1808
|
+
bootstrap_room = (
|
1809
|
+
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
|
1810
|
+
)
|
1811
|
+
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
|
1812
|
+
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
|
1813
|
+
else:
|
1814
|
+
for i in range(len(obj.rid)):
|
1815
|
+
bootstrap_room = (
|
1816
|
+
obj.bootstrap_room[i]
|
1817
|
+
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
|
1818
|
+
else None
|
1819
|
+
)
|
1820
|
+
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
|
1821
|
+
trace_slice_start(
|
1822
|
+
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
|
1823
|
+
)
|
1824
|
+
|
2091
1825
|
|
2092
1826
|
class ServerStatus(Enum):
|
2093
1827
|
Up = "Up"
|
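Both new helpers convert float `time.time()` seconds into the integer nanoseconds the tracing layer expects, via `ts=int(created_time * 1e9)`. A sketch of the same span bookkeeping; the `trace_*` functions here are stand-ins, not sglang's tracing API:

import time

_SPANS: dict[str, int] = {}


def trace_req_start(rid: str, ts: int | None = None) -> None:
    # Timestamps are integer nanoseconds, as in OpenTelemetry.
    _SPANS[rid] = ts if ts is not None else time.time_ns()


def trace_req_finish(rid: str, ts: int | None = None) -> int:
    end = ts if ts is not None else time.time_ns()
    return end - _SPANS.pop(rid)  # span duration in nanoseconds


created_time = time.time()
trace_req_start("req-1", ts=int(created_time * 1e9))
elapsed_ns = trace_req_finish("req-1")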
@@ -2133,57 +1867,12 @@ class SignalHandler:
 
     def running_phase_sigquit_handler(self, signum=None, frame=None):
         logger.error(
-            "
+            f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
         )
         self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())
 
 
-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    enable_multi_tokenizer = False
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
2187
1876
|
# Note: request abort handling logic
|
2188
1877
|
# We should handle all of the following cases correctly.
|
2189
1878
|
#
|