sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -31,18 +31,7 @@ from contextlib import nullcontext
|
|
31
31
|
from datetime import datetime
|
32
32
|
from enum import Enum
|
33
33
|
from http import HTTPStatus
|
34
|
-
from typing import
|
35
|
-
Any,
|
36
|
-
Awaitable,
|
37
|
-
Deque,
|
38
|
-
Dict,
|
39
|
-
Generic,
|
40
|
-
List,
|
41
|
-
Optional,
|
42
|
-
Tuple,
|
43
|
-
TypeVar,
|
44
|
-
Union,
|
45
|
-
)
|
34
|
+
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
|
46
35
|
|
47
36
|
import fastapi
|
48
37
|
import torch
|
@@ -53,80 +42,49 @@ from fastapi import BackgroundTasks
|
|
53
42
|
|
54
43
|
from sglang.srt.aio_rwlock import RWLock
|
55
44
|
from sglang.srt.configs.model_config import ModelConfig
|
56
|
-
from sglang.srt.disaggregation.utils import
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
get_kv_class,
|
61
|
-
)
|
62
|
-
from sglang.srt.hf_transformers_utils import (
|
63
|
-
get_processor,
|
64
|
-
get_tokenizer,
|
65
|
-
get_tokenizer_from_processor,
|
66
|
-
)
|
67
|
-
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
45
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
46
|
+
from sglang.srt.lora.lora_registry import LoRARegistry
|
47
|
+
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
|
48
|
+
from sglang.srt.managers.disagg_service import start_disagg_service
|
68
49
|
from sglang.srt.managers.io_struct import (
|
69
50
|
AbortReq,
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
51
|
+
BatchEmbeddingOutput,
|
52
|
+
BatchMultimodalOutput,
|
53
|
+
BatchStrOutput,
|
54
|
+
BatchTokenIDOutput,
|
74
55
|
BatchTokenizedEmbeddingReqInput,
|
75
56
|
BatchTokenizedGenerateReqInput,
|
76
|
-
ClearHiCacheReqInput,
|
77
|
-
ClearHiCacheReqOutput,
|
78
|
-
CloseSessionReqInput,
|
79
57
|
ConfigureLoggingReq,
|
80
58
|
EmbeddingReqInput,
|
81
|
-
ExpertDistributionReq,
|
82
|
-
ExpertDistributionReqOutput,
|
83
|
-
FlushCacheReqInput,
|
84
|
-
FlushCacheReqOutput,
|
85
59
|
FreezeGCReq,
|
86
60
|
GenerateReqInput,
|
87
|
-
|
88
|
-
GetInternalStateReqOutput,
|
89
|
-
GetWeightsByNameReqInput,
|
90
|
-
GetWeightsByNameReqOutput,
|
61
|
+
GetLoadReqInput,
|
91
62
|
HealthCheckOutput,
|
92
|
-
|
93
|
-
InitWeightsUpdateGroupReqOutput,
|
94
|
-
LoadLoRAAdapterReqInput,
|
95
|
-
LoadLoRAAdapterReqOutput,
|
96
|
-
LoRAUpdateResult,
|
97
|
-
MultiTokenizerWarpper,
|
98
|
-
OpenSessionReqInput,
|
63
|
+
MultiTokenizerWrapper,
|
99
64
|
OpenSessionReqOutput,
|
100
|
-
ProfileReq,
|
101
|
-
ProfileReqOutput,
|
102
|
-
ProfileReqType,
|
103
|
-
ReleaseMemoryOccupationReqInput,
|
104
|
-
ReleaseMemoryOccupationReqOutput,
|
105
|
-
ResumeMemoryOccupationReqInput,
|
106
|
-
ResumeMemoryOccupationReqOutput,
|
107
65
|
SessionParams,
|
108
|
-
SetInternalStateReq,
|
109
|
-
SetInternalStateReqOutput,
|
110
|
-
SlowDownReqInput,
|
111
|
-
SlowDownReqOutput,
|
112
66
|
TokenizedEmbeddingReqInput,
|
113
67
|
TokenizedGenerateReqInput,
|
114
|
-
UnloadLoRAAdapterReqInput,
|
115
|
-
UnloadLoRAAdapterReqOutput,
|
116
68
|
UpdateWeightFromDiskReqInput,
|
117
69
|
UpdateWeightFromDiskReqOutput,
|
118
|
-
|
119
|
-
UpdateWeightsFromDistributedReqOutput,
|
120
|
-
UpdateWeightsFromTensorReqInput,
|
121
|
-
UpdateWeightsFromTensorReqOutput,
|
70
|
+
WatchLoadUpdateReq,
|
122
71
|
)
|
123
72
|
from sglang.srt.managers.mm_utils import TensorTransportMode
|
124
73
|
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
125
74
|
from sglang.srt.managers.scheduler import is_health_check_generate_req
|
126
75
|
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
|
76
|
+
from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
|
127
77
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
128
78
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
129
79
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
80
|
+
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
81
|
+
from sglang.srt.tracing.trace import (
|
82
|
+
trace_get_proc_propagate_context,
|
83
|
+
trace_req_finish,
|
84
|
+
trace_req_start,
|
85
|
+
trace_slice_end,
|
86
|
+
trace_slice_start,
|
87
|
+
)
|
130
88
|
from sglang.srt.utils import (
|
131
89
|
configure_gc_warning,
|
132
90
|
dataclass_to_string_truncated,
|
@@ -136,6 +94,11 @@ from sglang.srt.utils import (
|
|
136
94
|
get_zmq_socket,
|
137
95
|
kill_process_tree,
|
138
96
|
)
|
97
|
+
from sglang.srt.utils.hf_transformers_utils import (
|
98
|
+
get_processor,
|
99
|
+
get_tokenizer,
|
100
|
+
get_tokenizer_from_processor,
|
101
|
+
)
|
139
102
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
140
103
|
|
141
104
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
@@ -180,7 +143,7 @@ class ReqState:
|
|
180
143
|
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
181
144
|
|
182
145
|
|
183
|
-
class TokenizerManager:
|
146
|
+
class TokenizerManager(TokenizerCommunicatorMixin):
|
184
147
|
"""TokenizerManager is a process that tokenizes the text."""
|
185
148
|
|
186
149
|
def __init__(
|
@@ -199,6 +162,7 @@ class TokenizerManager:
|
|
199
162
|
else None
|
200
163
|
)
|
201
164
|
self.crash_dump_folder = server_args.crash_dump_folder
|
165
|
+
self.enable_trace = server_args.enable_trace
|
202
166
|
|
203
167
|
# Read model args
|
204
168
|
self.model_path = server_args.model_path
|
@@ -210,8 +174,17 @@ class TokenizerManager:
|
|
210
174
|
self.image_token_id = self.model_config.image_token_id
|
211
175
|
self.max_req_input_len = None # Will be set later in engine.py
|
212
176
|
|
177
|
+
speculative_algorithm = SpeculativeAlgorithm.from_string(
|
178
|
+
server_args.speculative_algorithm
|
179
|
+
)
|
180
|
+
self.reserve_input_token_num = (
|
181
|
+
0
|
182
|
+
if speculative_algorithm.is_none()
|
183
|
+
else server_args.speculative_num_draft_tokens
|
184
|
+
)
|
185
|
+
|
213
186
|
if self.model_config.is_multimodal:
|
214
|
-
import_processors()
|
187
|
+
import_processors("sglang.srt.multimodal.processors")
|
215
188
|
try:
|
216
189
|
_processor = get_processor(
|
217
190
|
server_args.tokenizer_path,
|
@@ -262,6 +235,18 @@ class TokenizerManager:
|
|
262
235
|
trust_remote_code=server_args.trust_remote_code,
|
263
236
|
revision=server_args.revision,
|
264
237
|
)
|
238
|
+
# Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
|
239
|
+
if (
|
240
|
+
server_args.enable_dynamic_batch_tokenizer
|
241
|
+
and not server_args.skip_tokenizer_init
|
242
|
+
):
|
243
|
+
self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
|
244
|
+
self.tokenizer,
|
245
|
+
max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
|
246
|
+
batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
|
247
|
+
)
|
248
|
+
else:
|
249
|
+
self.async_dynamic_batch_tokenizer = None
|
265
250
|
|
266
251
|
# Init inter-process communication
|
267
252
|
context = zmq.asyncio.Context(2)
|
@@ -319,8 +304,10 @@ class TokenizerManager:
|
|
319
304
|
# LoRA updates and inference to overlap.
|
320
305
|
self.lora_update_lock = asyncio.Lock()
|
321
306
|
|
322
|
-
|
323
|
-
|
307
|
+
self.disaggregation_mode = DisaggregationMode(
|
308
|
+
self.server_args.disaggregation_mode
|
309
|
+
)
|
310
|
+
self.bootstrap_server = start_disagg_service(self.server_args)
|
324
311
|
|
325
312
|
# For load balancing
|
326
313
|
self.current_load = 0
|
@@ -328,12 +315,16 @@ class TokenizerManager:
|
|
328
315
|
|
329
316
|
# Metrics
|
330
317
|
if self.enable_metrics:
|
318
|
+
labels = {
|
319
|
+
"model_name": self.server_args.served_model_name,
|
320
|
+
# TODO: Add lora name/path in the future,
|
321
|
+
}
|
322
|
+
if server_args.tokenizer_metrics_allowed_custom_labels:
|
323
|
+
for label in server_args.tokenizer_metrics_allowed_custom_labels:
|
324
|
+
labels[label] = ""
|
331
325
|
self.metrics_collector = TokenizerMetricsCollector(
|
332
326
|
server_args=server_args,
|
333
|
-
labels=
|
334
|
-
"model_name": self.server_args.served_model_name,
|
335
|
-
# TODO: Add lora name/path in the future,
|
336
|
-
},
|
327
|
+
labels=labels,
|
337
328
|
bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
|
338
329
|
bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
|
339
330
|
bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
|
@@ -344,58 +335,14 @@ class TokenizerManager:
|
|
344
335
|
if self.server_args.gc_warning_threshold_secs > 0.0:
|
345
336
|
configure_gc_warning(self.server_args.gc_warning_threshold_secs)
|
346
337
|
|
347
|
-
# Communicators
|
348
|
-
self.init_weights_update_group_communicator = _Communicator(
|
349
|
-
self.send_to_scheduler, server_args.dp_size
|
350
|
-
)
|
351
|
-
self.update_weights_from_distributed_communicator = _Communicator(
|
352
|
-
self.send_to_scheduler, server_args.dp_size
|
353
|
-
)
|
354
|
-
self.update_weights_from_tensor_communicator = _Communicator(
|
355
|
-
self.send_to_scheduler, server_args.dp_size
|
356
|
-
)
|
357
|
-
self.get_weights_by_name_communicator = _Communicator(
|
358
|
-
self.send_to_scheduler, server_args.dp_size
|
359
|
-
)
|
360
|
-
self.release_memory_occupation_communicator = _Communicator(
|
361
|
-
self.send_to_scheduler, server_args.dp_size
|
362
|
-
)
|
363
|
-
self.resume_memory_occupation_communicator = _Communicator(
|
364
|
-
self.send_to_scheduler, server_args.dp_size
|
365
|
-
)
|
366
|
-
self.slow_down_communicator = _Communicator(
|
367
|
-
self.send_to_scheduler, server_args.dp_size
|
368
|
-
)
|
369
|
-
self.flush_cache_communicator = _Communicator(
|
370
|
-
self.send_to_scheduler, server_args.dp_size
|
371
|
-
)
|
372
|
-
self.clear_hicache_storage_communicator = _Communicator(
|
373
|
-
self.send_to_scheduler, server_args.dp_size
|
374
|
-
)
|
375
|
-
self.profile_communicator = _Communicator(
|
376
|
-
self.send_to_scheduler, server_args.dp_size
|
377
|
-
)
|
378
|
-
self.get_internal_state_communicator = _Communicator(
|
379
|
-
self.send_to_scheduler, server_args.dp_size
|
380
|
-
)
|
381
|
-
self.set_internal_state_communicator = _Communicator(
|
382
|
-
self.send_to_scheduler, server_args.dp_size
|
383
|
-
)
|
384
|
-
self.expert_distribution_communicator = _Communicator(
|
385
|
-
self.send_to_scheduler, server_args.dp_size
|
386
|
-
)
|
387
|
-
self.update_lora_adapter_communicator = _Communicator(
|
388
|
-
self.send_to_scheduler, server_args.dp_size
|
389
|
-
)
|
390
|
-
|
391
338
|
self._result_dispatcher = TypeBasedDispatcher(
|
392
339
|
[
|
393
340
|
(
|
394
341
|
(
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
342
|
+
BatchStrOutput,
|
343
|
+
BatchEmbeddingOutput,
|
344
|
+
BatchTokenIDOutput,
|
345
|
+
BatchMultimodalOutput,
|
399
346
|
),
|
400
347
|
self._handle_batch_output,
|
401
348
|
),
|
@@ -405,100 +352,15 @@ class TokenizerManager:
|
|
405
352
|
UpdateWeightFromDiskReqOutput,
|
406
353
|
self._handle_update_weights_from_disk_req_output,
|
407
354
|
),
|
408
|
-
(
|
409
|
-
InitWeightsUpdateGroupReqOutput,
|
410
|
-
self.init_weights_update_group_communicator.handle_recv,
|
411
|
-
),
|
412
|
-
(
|
413
|
-
UpdateWeightsFromDistributedReqOutput,
|
414
|
-
self.update_weights_from_distributed_communicator.handle_recv,
|
415
|
-
),
|
416
|
-
(
|
417
|
-
UpdateWeightsFromTensorReqOutput,
|
418
|
-
self.update_weights_from_tensor_communicator.handle_recv,
|
419
|
-
),
|
420
|
-
(
|
421
|
-
GetWeightsByNameReqOutput,
|
422
|
-
self.get_weights_by_name_communicator.handle_recv,
|
423
|
-
),
|
424
|
-
(
|
425
|
-
ReleaseMemoryOccupationReqOutput,
|
426
|
-
self.release_memory_occupation_communicator.handle_recv,
|
427
|
-
),
|
428
|
-
(
|
429
|
-
ResumeMemoryOccupationReqOutput,
|
430
|
-
self.resume_memory_occupation_communicator.handle_recv,
|
431
|
-
),
|
432
|
-
(
|
433
|
-
SlowDownReqOutput,
|
434
|
-
self.slow_down_communicator.handle_recv,
|
435
|
-
),
|
436
|
-
(
|
437
|
-
ClearHiCacheReqOutput,
|
438
|
-
self.clear_hicache_storage_communicator.handle_recv,
|
439
|
-
),
|
440
|
-
(
|
441
|
-
FlushCacheReqOutput,
|
442
|
-
self.flush_cache_communicator.handle_recv,
|
443
|
-
),
|
444
|
-
(
|
445
|
-
ProfileReqOutput,
|
446
|
-
self.profile_communicator.handle_recv,
|
447
|
-
),
|
448
355
|
(
|
449
356
|
FreezeGCReq,
|
450
357
|
lambda x: None,
|
451
358
|
), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
|
452
|
-
(
|
453
|
-
GetInternalStateReqOutput,
|
454
|
-
self.get_internal_state_communicator.handle_recv,
|
455
|
-
),
|
456
|
-
(
|
457
|
-
SetInternalStateReqOutput,
|
458
|
-
self.set_internal_state_communicator.handle_recv,
|
459
|
-
),
|
460
|
-
(
|
461
|
-
ExpertDistributionReqOutput,
|
462
|
-
self.expert_distribution_communicator.handle_recv,
|
463
|
-
),
|
464
|
-
(
|
465
|
-
LoRAUpdateResult,
|
466
|
-
self.update_lora_adapter_communicator.handle_recv,
|
467
|
-
),
|
468
359
|
(HealthCheckOutput, lambda x: None),
|
469
360
|
]
|
470
361
|
)
|
471
362
|
|
472
|
-
|
473
|
-
self.disaggregation_mode = DisaggregationMode(
|
474
|
-
self.server_args.disaggregation_mode
|
475
|
-
)
|
476
|
-
self.disaggregation_transfer_backend = TransferBackend(
|
477
|
-
self.server_args.disaggregation_transfer_backend
|
478
|
-
)
|
479
|
-
# Start kv boostrap server on prefill
|
480
|
-
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
481
|
-
# only start bootstrap server on prefill tm
|
482
|
-
kv_bootstrap_server_class = get_kv_class(
|
483
|
-
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
484
|
-
)
|
485
|
-
self.bootstrap_server = kv_bootstrap_server_class(
|
486
|
-
self.server_args.disaggregation_bootstrap_port
|
487
|
-
)
|
488
|
-
is_create_store = (
|
489
|
-
self.server_args.node_rank == 0
|
490
|
-
and self.server_args.disaggregation_transfer_backend == "ascend"
|
491
|
-
)
|
492
|
-
if is_create_store:
|
493
|
-
try:
|
494
|
-
from mf_adapter import create_config_store
|
495
|
-
|
496
|
-
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
497
|
-
create_config_store(ascend_url)
|
498
|
-
except Exception as e:
|
499
|
-
error_message = f"Failed create mf store, invalid ascend_url."
|
500
|
-
error_message += f" With exception {e}"
|
501
|
-
raise error_message
|
363
|
+
self.init_communicators(server_args)
|
502
364
|
|
503
365
|
async def generate_request(
|
504
366
|
self,
|
@@ -518,6 +380,9 @@ class TokenizerManager:
|
|
518
380
|
# If it's a single value, add worker_id prefix
|
519
381
|
obj.rid = f"{self.worker_id}_{obj.rid}"
|
520
382
|
|
383
|
+
if self.enable_trace:
|
384
|
+
self._trace_request_start(obj, created_time)
|
385
|
+
|
521
386
|
if self.log_requests:
|
522
387
|
max_length, skip_names, _ = self.log_request_metadata
|
523
388
|
logger.info(
|
@@ -543,6 +408,144 @@ class TokenizerManager:
|
|
543
408
|
):
|
544
409
|
yield response
|
545
410
|
|
411
|
+
def _detect_input_format(
|
412
|
+
self, texts: Union[str, List[str]], is_cross_encoder: bool
|
413
|
+
) -> str:
|
414
|
+
"""Detect the format of input texts for proper tokenization handling.
|
415
|
+
|
416
|
+
Returns:
|
417
|
+
- "single_string": Regular single text like "Hello world"
|
418
|
+
- "batch_strings": Regular batch like ["Hello", "World"]
|
419
|
+
- "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
|
420
|
+
"""
|
421
|
+
if isinstance(texts, str):
|
422
|
+
return "single_string"
|
423
|
+
|
424
|
+
if (
|
425
|
+
is_cross_encoder
|
426
|
+
and len(texts) > 0
|
427
|
+
and isinstance(texts[0], list)
|
428
|
+
and len(texts[0]) == 2
|
429
|
+
):
|
430
|
+
return "cross_encoder_pairs"
|
431
|
+
|
432
|
+
return "batch_strings"
|
433
|
+
|
434
|
+
def _prepare_tokenizer_input(
|
435
|
+
self, texts: Union[str, List[str]], input_format: str
|
436
|
+
) -> Union[List[str], List[List[str]]]:
|
437
|
+
"""Prepare input for the tokenizer based on detected format."""
|
438
|
+
if input_format == "single_string":
|
439
|
+
return [texts] # Wrap single string for batch processing
|
440
|
+
elif input_format == "cross_encoder_pairs":
|
441
|
+
return texts # Already in correct format: [["query", "doc"]]
|
442
|
+
else: # batch_strings
|
443
|
+
return texts # Already in correct format: ["text1", "text2"]
|
444
|
+
|
445
|
+
def _extract_tokenizer_results(
|
446
|
+
self,
|
447
|
+
input_ids: List[List[int]],
|
448
|
+
token_type_ids: Optional[List[List[int]]],
|
449
|
+
input_format: str,
|
450
|
+
original_batch_size: int,
|
451
|
+
) -> Union[
|
452
|
+
Tuple[List[int], Optional[List[int]]],
|
453
|
+
Tuple[List[List[int]], Optional[List[List[int]]]],
|
454
|
+
]:
|
455
|
+
"""Extract results from tokenizer output based on input format."""
|
456
|
+
|
457
|
+
# For single inputs (string or single cross-encoder pair), extract first element
|
458
|
+
if (
|
459
|
+
input_format in ["single_string", "cross_encoder_pairs"]
|
460
|
+
and original_batch_size == 1
|
461
|
+
):
|
462
|
+
single_input_ids = input_ids[0] if input_ids else []
|
463
|
+
single_token_type_ids = token_type_ids[0] if token_type_ids else None
|
464
|
+
return single_input_ids, single_token_type_ids
|
465
|
+
|
466
|
+
# For true batches, return as-is
|
467
|
+
return input_ids, token_type_ids
|
468
|
+
|
469
|
+
async def _tokenize_texts(
|
470
|
+
self, texts: Union[str, List[str]], is_cross_encoder: bool = False
|
471
|
+
) -> Union[
|
472
|
+
Tuple[List[int], Optional[List[int]]],
|
473
|
+
Tuple[List[List[int]], Optional[List[List[int]]]],
|
474
|
+
]:
|
475
|
+
"""
|
476
|
+
Tokenize text(s) using the appropriate tokenizer strategy.
|
477
|
+
|
478
|
+
This method handles multiple input formats and chooses between async dynamic
|
479
|
+
batch tokenizer (for single texts only) and regular tokenizer.
|
480
|
+
|
481
|
+
Args:
|
482
|
+
texts: Text input in various formats:
|
483
|
+
|
484
|
+
Regular cases:
|
485
|
+
- Single string: "How are you?"
|
486
|
+
- Batch of strings: ["Hello", "World", "How are you?"]
|
487
|
+
|
488
|
+
Cross-encoder cases (sentence pairs for similarity/ranking):
|
489
|
+
- Single pair: [["query text", "document text"]]
|
490
|
+
- Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
|
491
|
+
|
492
|
+
is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
|
493
|
+
Enables proper handling of sentence pairs with segment IDs.
|
494
|
+
|
495
|
+
Returns:
|
496
|
+
Single input cases:
|
497
|
+
Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
|
498
|
+
Example: ([101, 2129, 102], [0, 0, 0]) for single text
|
499
|
+
Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
|
500
|
+
|
501
|
+
Batch input cases:
|
502
|
+
Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
|
503
|
+
Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
|
504
|
+
|
505
|
+
Note: token_type_ids is None unless is_cross_encoder=True.
|
506
|
+
"""
|
507
|
+
if not texts or self.tokenizer is None:
|
508
|
+
raise ValueError("texts cannot be empty and tokenizer must be initialized")
|
509
|
+
|
510
|
+
# Step 1: Detect input format and prepare for tokenization
|
511
|
+
input_format = self._detect_input_format(texts, is_cross_encoder)
|
512
|
+
tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
|
513
|
+
original_batch_size = len(texts) if not isinstance(texts, str) else 1
|
514
|
+
|
515
|
+
# Step 2: Set up tokenizer arguments
|
516
|
+
tokenizer_kwargs = (
|
517
|
+
{"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
|
518
|
+
)
|
519
|
+
|
520
|
+
# Step 3: Choose tokenization strategy
|
521
|
+
use_async_tokenizer = (
|
522
|
+
self.async_dynamic_batch_tokenizer is not None
|
523
|
+
and input_format == "single_string"
|
524
|
+
)
|
525
|
+
|
526
|
+
if use_async_tokenizer:
|
527
|
+
logger.debug("Using async dynamic batch tokenizer for single text")
|
528
|
+
result = await self.async_dynamic_batch_tokenizer.encode(
|
529
|
+
tokenizer_input[0], **tokenizer_kwargs
|
530
|
+
)
|
531
|
+
# Convert to batch format for consistency
|
532
|
+
input_ids = [result["input_ids"]]
|
533
|
+
token_type_ids = (
|
534
|
+
[result["token_type_ids"]]
|
535
|
+
if is_cross_encoder and result.get("token_type_ids")
|
536
|
+
else None
|
537
|
+
)
|
538
|
+
else:
|
539
|
+
logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
|
540
|
+
encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
|
541
|
+
input_ids = encoded["input_ids"]
|
542
|
+
token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
|
543
|
+
|
544
|
+
# Step 4: Extract results based on input format
|
545
|
+
return self._extract_tokenizer_results(
|
546
|
+
input_ids, token_type_ids, input_format, original_batch_size
|
547
|
+
)
|
548
|
+
|
546
549
|
async def _tokenize_one_request(
|
547
550
|
self,
|
548
551
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
@@ -573,14 +576,10 @@ class TokenizerManager:
|
|
573
576
|
"accept text prompts. Please provide input_ids or re-initialize "
|
574
577
|
"the engine with skip_tokenizer_init=False."
|
575
578
|
)
|
576
|
-
encoded = self.tokenizer(
|
577
|
-
input_text, return_token_type_ids=is_cross_encoder_request
|
578
|
-
)
|
579
579
|
|
580
|
-
input_ids =
|
581
|
-
|
582
|
-
|
583
|
-
token_type_ids = encoded.get("token_type_ids", [None])[0]
|
580
|
+
input_ids, token_type_ids = await self._tokenize_texts(
|
581
|
+
input_text, is_cross_encoder_request
|
582
|
+
)
|
584
583
|
|
585
584
|
if self.mm_processor and obj.contains_mm_input():
|
586
585
|
if not isinstance(obj.image_data, list):
|
@@ -600,6 +599,7 @@ class TokenizerManager:
|
|
600
599
|
mm_inputs = None
|
601
600
|
|
602
601
|
self._validate_one_request(obj, input_ids)
|
602
|
+
trace_slice_end("tokenize", obj.rid)
|
603
603
|
return self._create_tokenized_object(
|
604
604
|
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
|
605
605
|
)
|
@@ -612,6 +612,7 @@ class TokenizerManager:
|
|
612
612
|
_max_req_len = self.context_len
|
613
613
|
|
614
614
|
input_token_num = len(input_ids) if input_ids is not None else 0
|
615
|
+
input_token_num += self.reserve_input_token_num
|
615
616
|
if input_token_num >= self.context_len:
|
616
617
|
if self.server_args.allow_auto_truncate:
|
617
618
|
logger.warning(
|
@@ -674,7 +675,7 @@ class TokenizerManager:
|
|
674
675
|
):
|
675
676
|
raise ValueError(
|
676
677
|
"The server is not configured to enable custom logit processor. "
|
677
|
-
"Please set `--enable-custom-
|
678
|
+
"Please set `--enable-custom-logit-processor` to enable this feature."
|
678
679
|
)
|
679
680
|
|
680
681
|
def _validate_input_ids_in_vocab(
|
@@ -713,7 +714,6 @@ class TokenizerManager:
|
|
713
714
|
)
|
714
715
|
|
715
716
|
tokenized_obj = TokenizedGenerateReqInput(
|
716
|
-
obj.rid,
|
717
717
|
input_text,
|
718
718
|
input_ids,
|
719
719
|
mm_inputs,
|
@@ -723,6 +723,7 @@ class TokenizerManager:
|
|
723
723
|
obj.top_logprobs_num,
|
724
724
|
obj.token_ids_logprob,
|
725
725
|
obj.stream,
|
726
|
+
rid=obj.rid,
|
726
727
|
bootstrap_host=obj.bootstrap_host,
|
727
728
|
bootstrap_port=obj.bootstrap_port,
|
728
729
|
bootstrap_room=obj.bootstrap_room,
|
@@ -732,15 +733,18 @@ class TokenizerManager:
|
|
732
733
|
custom_logit_processor=obj.custom_logit_processor,
|
733
734
|
return_hidden_states=obj.return_hidden_states,
|
734
735
|
data_parallel_rank=obj.data_parallel_rank,
|
736
|
+
priority=obj.priority,
|
737
|
+
extra_key=obj.extra_key,
|
735
738
|
)
|
736
739
|
elif isinstance(obj, EmbeddingReqInput):
|
737
740
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
738
|
-
obj.rid,
|
739
741
|
input_text,
|
740
742
|
input_ids,
|
741
743
|
mm_inputs,
|
742
744
|
token_type_ids,
|
743
745
|
sampling_params,
|
746
|
+
rid=obj.rid,
|
747
|
+
priority=obj.priority,
|
744
748
|
)
|
745
749
|
|
746
750
|
return tokenized_obj
|
@@ -755,19 +759,30 @@ class TokenizerManager:
|
|
755
759
|
requests = [obj[i] for i in range(batch_size)]
|
756
760
|
texts = [req.text for req in requests]
|
757
761
|
|
758
|
-
#
|
759
|
-
|
760
|
-
|
762
|
+
# Check if any request is a cross-encoder request
|
763
|
+
is_cross_encoder_request = any(
|
764
|
+
isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
|
765
|
+
for req in requests
|
766
|
+
)
|
767
|
+
|
768
|
+
# Batch tokenize all texts using unified method
|
769
|
+
input_ids_list, token_type_ids_list = await self._tokenize_texts(
|
770
|
+
texts, is_cross_encoder_request
|
771
|
+
)
|
761
772
|
|
762
773
|
# Process all requests
|
763
774
|
tokenized_objs = []
|
764
775
|
for i, req in enumerate(requests):
|
765
776
|
self._validate_one_request(obj[i], input_ids_list[i])
|
777
|
+
token_type_ids = (
|
778
|
+
token_type_ids_list[i] if token_type_ids_list is not None else None
|
779
|
+
)
|
766
780
|
tokenized_objs.append(
|
767
781
|
self._create_tokenized_object(
|
768
|
-
req, req.text, input_ids_list[i], None, None
|
782
|
+
req, req.text, input_ids_list[i], None, None, token_type_ids
|
769
783
|
)
|
770
784
|
)
|
785
|
+
trace_slice_end("tokenize", req.rid)
|
771
786
|
logger.debug(f"Completed batch processing for {batch_size} requests")
|
772
787
|
return tokenized_objs
|
773
788
|
|
@@ -795,9 +810,12 @@ class TokenizerManager:
|
|
795
810
|
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
796
811
|
created_time: Optional[float] = None,
|
797
812
|
):
|
813
|
+
trace_slice_start("dispatch", obj.rid)
|
814
|
+
tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
|
798
815
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
799
816
|
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
800
817
|
self.rid_to_state[obj.rid] = state
|
818
|
+
trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
|
801
819
|
return state
|
802
820
|
|
803
821
|
def _send_batch_request(
|
@@ -1015,73 +1033,16 @@ class TokenizerManager:
|
|
1015
1033
|
except StopAsyncIteration:
|
1016
1034
|
pass
|
1017
1035
|
|
1018
|
-
async def flush_cache(self) -> FlushCacheReqOutput:
|
1019
|
-
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
|
1020
|
-
|
1021
|
-
async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
|
1022
|
-
"""Clear the hierarchical cache storage."""
|
1023
|
-
# Delegate to the scheduler to handle HiCacheStorage clearing
|
1024
|
-
return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
|
1025
|
-
0
|
1026
|
-
]
|
1027
|
-
|
1028
1036
|
def abort_request(self, rid: str = "", abort_all: bool = False):
|
1029
1037
|
if not abort_all and rid not in self.rid_to_state:
|
1030
1038
|
return
|
1031
|
-
req = AbortReq(rid, abort_all)
|
1039
|
+
req = AbortReq(rid=rid, abort_all=abort_all)
|
1032
1040
|
self.send_to_scheduler.send_pyobj(req)
|
1033
|
-
|
1034
1041
|
if self.enable_metrics:
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
output_dir: Optional[str] = None,
|
1040
|
-
start_step: Optional[int] = None,
|
1041
|
-
num_steps: Optional[int] = None,
|
1042
|
-
activities: Optional[List[str]] = None,
|
1043
|
-
with_stack: Optional[bool] = None,
|
1044
|
-
record_shapes: Optional[bool] = None,
|
1045
|
-
profile_by_stage: bool = False,
|
1046
|
-
):
|
1047
|
-
self.auto_create_handle_loop()
|
1048
|
-
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
|
1049
|
-
with_stack = False if with_stack is False or env_with_stack is False else True
|
1050
|
-
req = ProfileReq(
|
1051
|
-
type=ProfileReqType.START_PROFILE,
|
1052
|
-
output_dir=output_dir,
|
1053
|
-
start_step=start_step,
|
1054
|
-
num_steps=num_steps,
|
1055
|
-
activities=activities,
|
1056
|
-
with_stack=with_stack,
|
1057
|
-
record_shapes=record_shapes,
|
1058
|
-
profile_by_stage=profile_by_stage,
|
1059
|
-
profile_id=str(time.time()),
|
1060
|
-
)
|
1061
|
-
return await self._execute_profile(req)
|
1062
|
-
|
1063
|
-
async def stop_profile(self):
|
1064
|
-
self.auto_create_handle_loop()
|
1065
|
-
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
1066
|
-
return await self._execute_profile(req)
|
1067
|
-
|
1068
|
-
async def _execute_profile(self, req: ProfileReq):
|
1069
|
-
result = (await self.profile_communicator(req))[0]
|
1070
|
-
if not result.success:
|
1071
|
-
raise RuntimeError(result.message)
|
1072
|
-
return result
|
1073
|
-
|
1074
|
-
async def start_expert_distribution_record(self):
|
1075
|
-
self.auto_create_handle_loop()
|
1076
|
-
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
|
1077
|
-
|
1078
|
-
async def stop_expert_distribution_record(self):
|
1079
|
-
self.auto_create_handle_loop()
|
1080
|
-
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
|
1081
|
-
|
1082
|
-
async def dump_expert_distribution_record(self):
|
1083
|
-
self.auto_create_handle_loop()
|
1084
|
-
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
1042
|
+
# TODO: also use custom_labels from the request
|
1043
|
+
self.metrics_collector.observe_one_aborted_request(
|
1044
|
+
self.metrics_collector.labels
|
1045
|
+
)
|
1085
1046
|
|
1086
1047
|
async def pause_generation(self):
|
1087
1048
|
async with self.is_pause_cond:
|
@@ -1118,7 +1079,7 @@ class TokenizerManager:
|
|
1118
1079
|
self, obj: UpdateWeightFromDiskReqInput
|
1119
1080
|
) -> Tuple[bool, str]:
|
1120
1081
|
if self.server_args.tokenizer_worker_num > 1:
|
1121
|
-
obj =
|
1082
|
+
obj = MultiTokenizerWrapper(self.worker_id, obj)
|
1122
1083
|
self.send_to_scheduler.send_pyobj(obj)
|
1123
1084
|
self.model_update_result = asyncio.Future()
|
1124
1085
|
if self.server_args.dp_size == 1:
|
@@ -1143,291 +1104,6 @@ class TokenizerManager:
|
|
1143
1104
|
all_paused_requests = [r.num_paused_requests for r in result]
|
1144
1105
|
return all_success, all_message, all_paused_requests
|
1145
1106
|
|
1146
|
-
async def init_weights_update_group(
|
1147
|
-
self,
|
1148
|
-
obj: InitWeightsUpdateGroupReqInput,
|
1149
|
-
request: Optional[fastapi.Request] = None,
|
1150
|
-
) -> Tuple[bool, str]:
|
1151
|
-
self.auto_create_handle_loop()
|
1152
|
-
assert (
|
1153
|
-
self.server_args.dp_size == 1
|
1154
|
-
), "dp_size must be 1 for init parameter update group"
|
1155
|
-
result = (await self.init_weights_update_group_communicator(obj))[0]
|
1156
|
-
return result.success, result.message
|
1157
|
-
|
1158
|
-
async def update_weights_from_distributed(
|
1159
|
-
self,
|
1160
|
-
obj: UpdateWeightsFromDistributedReqInput,
|
1161
|
-
request: Optional[fastapi.Request] = None,
|
1162
|
-
) -> Tuple[bool, str]:
|
1163
|
-
self.auto_create_handle_loop()
|
1164
|
-
assert (
|
1165
|
-
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
1166
|
-
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
|
1167
|
-
|
1168
|
-
if obj.abort_all_requests:
|
1169
|
-
self.abort_request(abort_all=True)
|
1170
|
-
|
1171
|
-
# This means that weight sync
|
1172
|
-
# cannot run while requests are in progress.
|
1173
|
-
async with self.model_update_lock.writer_lock:
|
1174
|
-
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
1175
|
-
return result.success, result.message
|
1176
|
-
|
1177
|
-
async def update_weights_from_tensor(
|
1178
|
-
self,
|
1179
|
-
obj: UpdateWeightsFromTensorReqInput,
|
1180
|
-
request: Optional[fastapi.Request] = None,
|
1181
|
-
) -> Tuple[bool, str]:
|
1182
|
-
self.auto_create_handle_loop()
|
1183
|
-
assert (
|
1184
|
-
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
1185
|
-
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
|
1186
|
-
|
1187
|
-
if obj.abort_all_requests:
|
1188
|
-
self.abort_request(abort_all=True)
|
1189
|
-
|
1190
|
-
# This means that weight sync
|
1191
|
-
# cannot run while requests are in progress.
|
1192
|
-
async with self.model_update_lock.writer_lock:
|
1193
|
-
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
1194
|
-
return result.success, result.message
|
1195
|
-
|
1196
|
-
async def load_lora_adapter(
|
1197
|
-
self,
|
1198
|
-
obj: LoadLoRAAdapterReqInput,
|
1199
|
-
_: Optional[fastapi.Request] = None,
|
1200
|
-
) -> LoadLoRAAdapterReqOutput:
|
1201
|
-
self.auto_create_handle_loop()
|
1202
|
-
|
1203
|
-
try:
|
1204
|
-
if not self.server_args.enable_lora:
|
1205
|
-
raise ValueError(
|
1206
|
-
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
1207
|
-
)
|
1208
|
-
|
1209
|
-
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
|
1210
|
-
# with dp_size > 1.
|
1211
|
-
assert (
|
1212
|
-
self.server_args.dp_size == 1
|
1213
|
-
), "dp_size must be 1 for dynamic lora loading"
|
1214
|
-
logger.info(
|
1215
|
-
"Start load Lora adapter. Lora name=%s, path=%s",
|
1216
|
-
obj.lora_name,
|
1217
|
-
obj.lora_path,
|
1218
|
-
)
|
1219
|
-
|
1220
|
-
async with self.lora_update_lock:
|
1221
|
-
if (
|
1222
|
-
self.server_args.max_loaded_loras is not None
|
1223
|
-
and self.lora_registry.num_registered_loras
|
1224
|
-
>= self.server_args.max_loaded_loras
|
1225
|
-
):
|
1226
|
-
raise ValueError(
|
1227
|
-
f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
|
1228
|
-
f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
|
1229
|
-
"Please unload some LoRA adapters before loading new ones."
|
1230
|
-
)
|
1231
|
-
|
1232
|
-
# Generate new uniquely identifiable LoRARef object.
|
1233
|
-
new_adapter = LoRARef(
|
1234
|
-
lora_name=obj.lora_name,
|
1235
|
-
lora_path=obj.lora_path,
|
1236
|
-
pinned=obj.pinned,
|
1237
|
-
)
|
1238
|
-
|
1239
|
-
# Trigger the actual loading operation at the backend processes.
|
1240
|
-
obj.lora_id = new_adapter.lora_id
|
1241
|
-
result = (await self.update_lora_adapter_communicator(obj))[0]
|
1242
|
-
|
1243
|
-
# Register the LoRA adapter only after loading is successful.
|
1244
|
-
if result.success:
|
1245
|
-
await self.lora_registry.register(new_adapter)
|
1246
|
-
|
1247
|
-
return result
|
1248
|
-
except ValueError as e:
|
1249
|
-
return LoadLoRAAdapterReqOutput(
|
1250
|
-
success=False,
|
1251
|
-
error_message=str(e),
|
1252
|
-
)
|
1253
|
-
|
1254
|
-
async def unload_lora_adapter(
|
1255
|
-
self,
|
1256
|
-
obj: UnloadLoRAAdapterReqInput,
|
1257
|
-
_: Optional[fastapi.Request] = None,
|
1258
|
-
) -> UnloadLoRAAdapterReqOutput:
|
1259
|
-
self.auto_create_handle_loop()
|
1260
|
-
|
1261
|
-
try:
|
1262
|
-
if not self.server_args.enable_lora:
|
1263
|
-
raise ValueError(
|
1264
|
-
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
1265
|
-
)
|
1266
|
-
|
1267
|
-
assert (
|
1268
|
-
obj.lora_name is not None
|
1269
|
-
), "lora_name must be provided to unload LoRA adapter"
|
1270
|
-
|
1271
|
-
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
|
1272
|
-
# with dp_size > 1.
|
1273
|
-
assert (
|
1274
|
-
self.server_args.dp_size == 1
|
1275
|
-
), "dp_size must be 1 for dynamic lora loading"
|
1276
|
-
logger.info(
|
1277
|
-
"Start unload Lora adapter. Lora name=%s",
|
1278
|
-
obj.lora_name,
|
1279
|
-
)
|
1280
|
-
|
1281
|
-
async with self.lora_update_lock:
|
1282
|
-
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
|
1283
|
-
# from being started.
|
1284
|
-
lora_id = await self.lora_registry.unregister(obj.lora_name)
|
1285
|
-
obj.lora_id = lora_id
|
1286
|
-
|
1287
|
-
# Initiate the actual unloading operation at the backend processes only after all
|
1288
|
-
# ongoing requests using this LoRA adapter are finished.
|
1289
|
-
await self.lora_registry.wait_for_unload(lora_id)
|
1290
|
-
result = (await self.update_lora_adapter_communicator(obj))[0]
|
1291
|
-
|
1292
|
-
return result
|
1293
|
-
except ValueError as e:
|
1294
|
-
return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
|
1295
|
-
|
1296
|
-
async def get_weights_by_name(
|
1297
|
-
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
1298
|
-
):
|
1299
|
-
self.auto_create_handle_loop()
|
1300
|
-
results = await self.get_weights_by_name_communicator(obj)
|
1301
|
-
all_parameters = [r.parameter for r in results]
|
1302
|
-
if self.server_args.dp_size == 1:
|
1303
|
-
return all_parameters[0]
|
1304
|
-
else:
|
1305
|
-
return all_parameters
|
1306
|
-
|
1307
|
-
async def release_memory_occupation(
|
1308
|
-
self,
|
1309
|
-
obj: ReleaseMemoryOccupationReqInput,
|
1310
|
-
request: Optional[fastapi.Request] = None,
|
1311
|
-
):
|
1312
|
-
self.auto_create_handle_loop()
|
1313
|
-
await self.release_memory_occupation_communicator(obj)
|
1314
|
-
|
1315
|
-
async def resume_memory_occupation(
|
1316
|
-
self,
|
1317
|
-
obj: ResumeMemoryOccupationReqInput,
|
1318
|
-
request: Optional[fastapi.Request] = None,
|
1319
|
-
):
|
1320
|
-
self.auto_create_handle_loop()
|
1321
|
-
await self.resume_memory_occupation_communicator(obj)
|
1322
|
-
|
1323
|
-
async def slow_down(
|
1324
|
-
self,
|
1325
|
-
obj: SlowDownReqInput,
|
1326
|
-
request: Optional[fastapi.Request] = None,
|
1327
|
-
):
|
1328
|
-
self.auto_create_handle_loop()
|
1329
|
-
await self.slow_down_communicator(obj)
|
1330
|
-
|
1331
|
-
async def open_session(
|
1332
|
-
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
1333
|
-
):
|
1334
|
-
self.auto_create_handle_loop()
|
1335
|
-
|
1336
|
-
if obj.session_id is None:
|
1337
|
-
obj.session_id = uuid.uuid4().hex
|
1338
|
-
elif obj.session_id in self.session_futures:
|
1339
|
-
return None
|
1340
|
-
|
1341
|
-
if self.server_args.tokenizer_worker_num > 1:
|
1342
|
-
obj = MultiTokenizerWarpper(self.worker_id, obj)
|
1343
|
-
self.send_to_scheduler.send_pyobj(obj)
|
1344
|
-
|
1345
|
-
self.session_futures[obj.session_id] = asyncio.Future()
|
1346
|
-
session_id = await self.session_futures[obj.session_id]
|
1347
|
-
del self.session_futures[obj.session_id]
|
1348
|
-
return session_id
|
1349
|
-
|
1350
|
-
async def close_session(
|
1351
|
-
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
1352
|
-
):
|
1353
|
-
await self.send_to_scheduler.send_pyobj(obj)
|
1354
|
-
|
1355
|
-
async def get_internal_state(self) -> List[Dict[Any, Any]]:
|
1356
|
-
req = GetInternalStateReq()
|
1357
|
-
responses: List[GetInternalStateReqOutput] = (
|
1358
|
-
await self.get_internal_state_communicator(req)
|
1359
|
-
)
|
1360
|
-
# Many DP ranks
|
1361
|
-
return [res.internal_state for res in responses]
|
1362
|
-
|
1363
|
-
async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
|
1364
|
-
responses: List[SetInternalStateReqOutput] = (
|
1365
|
-
await self.set_internal_state_communicator(obj)
|
1366
|
-
)
|
1367
|
-
return [res.updated for res in responses]
|
1368
|
-
|
1369
|
-
async def get_load(self) -> dict:
|
1370
|
-
# TODO(lsyin): fake load report server
|
1371
|
-
if not self.current_load_lock.locked():
|
1372
|
-
async with self.current_load_lock:
|
1373
|
-
internal_state = await self.get_internal_state()
|
1374
|
-
self.current_load = internal_state[0]["load"]
|
1375
|
-
return {"load": self.current_load}
|
1376
|
-
|
1377
|
-
def get_log_request_metadata(self):
|
1378
|
-
max_length = None
|
1379
|
-
skip_names = None
|
1380
|
-
out_skip_names = None
|
1381
|
-
if self.log_requests:
|
1382
|
-
if self.log_requests_level == 0:
|
1383
|
-
max_length = 1 << 30
|
1384
|
-
skip_names = set(
|
1385
|
-
[
|
1386
|
-
"text",
|
1387
|
-
"input_ids",
|
1388
|
-
"input_embeds",
|
1389
|
-
"image_data",
|
1390
|
-
"audio_data",
|
1391
|
-
"lora_path",
|
1392
|
-
"sampling_params",
|
1393
|
-
]
|
1394
|
-
)
|
1395
|
-
out_skip_names = set(
|
1396
|
-
[
|
1397
|
-
"text",
|
1398
|
-
"output_ids",
|
1399
|
-
"embedding",
|
1400
|
-
]
|
1401
|
-
)
|
1402
|
-
elif self.log_requests_level == 1:
|
1403
|
-
max_length = 1 << 30
|
1404
|
-
skip_names = set(
|
1405
|
-
[
|
1406
|
-
"text",
|
1407
|
-
"input_ids",
|
1408
|
-
"input_embeds",
|
1409
|
-
"image_data",
|
1410
|
-
"audio_data",
|
1411
|
-
"lora_path",
|
1412
|
-
]
|
1413
|
-
)
|
1414
|
-
out_skip_names = set(
|
1415
|
-
[
|
1416
|
-
"text",
|
1417
|
-
"output_ids",
|
1418
|
-
"embedding",
|
1419
|
-
]
|
1420
|
-
)
|
1421
|
-
elif self.log_requests_level == 2:
|
1422
|
-
max_length = 2048
|
1423
|
-
elif self.log_requests_level == 3:
|
1424
|
-
max_length = 1 << 30
|
1425
|
-
else:
|
1426
|
-
raise ValueError(
|
1427
|
-
f"Invalid --log-requests-level: {self.log_requests_level=}"
|
1428
|
-
)
|
1429
|
-
return max_length, skip_names, out_skip_names
|
1430
|
-
|
1431
1107
|
def configure_logging(self, obj: ConfigureLoggingReq):
|
1432
1108
|
if obj.log_requests is not None:
|
1433
1109
|
self.log_requests = obj.log_requests
|
@@ -1492,6 +1168,9 @@ class TokenizerManager:
|
|
1492
1168
|
self.asyncio_tasks.add(
|
1493
1169
|
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
|
1494
1170
|
)
|
1171
|
+
self.asyncio_tasks.add(
|
1172
|
+
loop.create_task(print_exception_wrapper(self.watch_load_thread))
|
1173
|
+
)
|
1495
1174
|
|
1496
1175
|
def dump_requests_before_crash(self):
|
1497
1176
|
if self.crash_dump_performed:
|
@@ -1583,12 +1262,12 @@ class TokenizerManager:
|
|
1583
1262
|
# Drain requests
|
1584
1263
|
while True:
|
1585
1264
|
remain_num_req = len(self.rid_to_state)
|
1265
|
+
remaining_rids = list(self.rid_to_state.keys())
|
1586
1266
|
|
1587
1267
|
if self.server_status == ServerStatus.UnHealthy:
|
1588
1268
|
# if health check failed, we should exit immediately
|
1589
1269
|
logger.error(
|
1590
|
-
"Signal SIGTERM received while health check failed.
|
1591
|
-
remain_num_req,
|
1270
|
+
"Signal SIGTERM received while health check failed. Force exiting."
|
1592
1271
|
)
|
1593
1272
|
self.dump_requests_before_crash()
|
1594
1273
|
break
|
@@ -1596,13 +1275,12 @@ class TokenizerManager:
|
|
1596
1275
|
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
|
1597
1276
|
# if force shutdown flag set, exit immediately
|
1598
1277
|
logger.error(
|
1599
|
-
"Signal SIGTERM received while force shutdown flag set. Force exiting
|
1600
|
-
remain_num_req,
|
1278
|
+
"Signal SIGTERM received while force shutdown flag set. Force exiting."
|
1601
1279
|
)
|
1602
1280
|
break
|
1603
1281
|
|
1604
1282
|
logger.info(
|
1605
|
-
f"Gracefully exiting...
|
1283
|
+
f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
|
1606
1284
|
)
|
1607
1285
|
if remain_num_req > 0:
|
1608
1286
|
await asyncio.sleep(5)
|
@@ -1623,7 +1301,10 @@ class TokenizerManager:
|
|
1623
1301
|
def _handle_batch_output(
|
1624
1302
|
self,
|
1625
1303
|
recv_obj: Union[
|
1626
|
-
|
1304
|
+
BatchStrOutput,
|
1305
|
+
BatchEmbeddingOutput,
|
1306
|
+
BatchMultimodalOutput,
|
1307
|
+
BatchTokenIDOutput,
|
1627
1308
|
],
|
1628
1309
|
):
|
1629
1310
|
for i, rid in enumerate(recv_obj.rids):
|
@@ -1657,7 +1338,7 @@ class TokenizerManager:
|
|
1657
1338
|
i,
|
1658
1339
|
)
|
1659
1340
|
|
1660
|
-
if not isinstance(recv_obj,
|
1341
|
+
if not isinstance(recv_obj, BatchEmbeddingOutput):
|
1661
1342
|
meta_info.update(
|
1662
1343
|
{
|
1663
1344
|
"completion_tokens": recv_obj.completion_tokens[i],
|
@@ -1668,7 +1349,7 @@ class TokenizerManager:
|
|
1668
1349
|
if getattr(recv_obj, "output_hidden_states", None):
|
1669
1350
|
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
1670
1351
|
|
1671
|
-
if isinstance(recv_obj,
|
1352
|
+
if isinstance(recv_obj, BatchStrOutput):
|
1672
1353
|
state.text += recv_obj.output_strs[i]
|
1673
1354
|
if state.obj.stream:
|
1674
1355
|
state.output_ids.extend(recv_obj.output_ids[i])
|
@@ -1683,7 +1364,7 @@ class TokenizerManager:
|
|
1683
1364
|
"output_ids": output_token_ids,
|
1684
1365
|
"meta_info": meta_info,
|
1685
1366
|
}
|
1686
|
-
elif isinstance(recv_obj,
|
1367
|
+
elif isinstance(recv_obj, BatchTokenIDOutput):
|
1687
1368
|
if self.server_args.stream_output and state.obj.stream:
|
1688
1369
|
state.output_ids.extend(recv_obj.output_ids[i])
|
1689
1370
|
output_token_ids = state.output_ids[state.last_output_offset :]
|
@@ -1696,10 +1377,10 @@ class TokenizerManager:
|
|
1696
1377
|
"output_ids": output_token_ids,
|
1697
1378
|
"meta_info": meta_info,
|
1698
1379
|
}
|
1699
|
-
elif isinstance(recv_obj,
|
1380
|
+
elif isinstance(recv_obj, BatchMultimodalOutput):
|
1700
1381
|
raise NotImplementedError("BatchMultimodalOut not implemented")
|
1701
1382
|
else:
|
1702
|
-
assert isinstance(recv_obj,
|
1383
|
+
assert isinstance(recv_obj, BatchEmbeddingOutput)
|
1703
1384
|
out_dict = {
|
1704
1385
|
"embedding": recv_obj.embeddings[i],
|
1705
1386
|
"meta_info": meta_info,
|
@@ -1711,6 +1392,9 @@ class TokenizerManager:
|
|
1711
1392
|
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
1712
1393
|
state.finished_time = time.time()
|
1713
1394
|
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
1395
|
+
|
1396
|
+
trace_req_finish(rid, ts=int(state.finished_time * 1e9))
|
1397
|
+
|
1714
1398
|
del self.rid_to_state[rid]
|
1715
1399
|
|
1716
1400
|
# Mark ongoing LoRA request as finished.
|
@@ -1735,7 +1419,7 @@ class TokenizerManager:
|
|
1735
1419
|
top_logprobs_num: int,
|
1736
1420
|
token_ids_logprob: List[int],
|
1737
1421
|
return_text_in_logprobs: bool,
|
1738
|
-
recv_obj:
|
1422
|
+
recv_obj: BatchStrOutput,
|
1739
1423
|
recv_obj_index: int,
|
1740
1424
|
):
|
1741
1425
|
if recv_obj.input_token_logprobs_val is None:
|
@@ -1853,13 +1537,19 @@ class TokenizerManager:
|
|
1853
1537
|
ret.append(None)
|
1854
1538
|
return ret
|
1855
1539
|
|
1856
|
-
def collect_metrics(self, state: ReqState, recv_obj:
|
1540
|
+
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
|
1857
1541
|
completion_tokens = (
|
1858
1542
|
recv_obj.completion_tokens[i]
|
1859
1543
|
if getattr(recv_obj, "completion_tokens", None)
|
1860
1544
|
else 0
|
1861
1545
|
)
|
1862
1546
|
|
1547
|
+
custom_labels = getattr(state.obj, "custom_labels", None)
|
1548
|
+
labels = (
|
1549
|
+
{**self.metrics_collector.labels, **custom_labels}
|
1550
|
+
if custom_labels
|
1551
|
+
else self.metrics_collector.labels
|
1552
|
+
)
|
1863
1553
|
if (
|
1864
1554
|
state.first_token_time == 0.0
|
1865
1555
|
and self.disaggregation_mode != DisaggregationMode.PREFILL
|
@@ -1867,7 +1557,7 @@ class TokenizerManager:
|
|
1867
1557
|
state.first_token_time = state.last_time = time.time()
|
1868
1558
|
state.last_completion_tokens = completion_tokens
|
1869
1559
|
self.metrics_collector.observe_time_to_first_token(
|
1870
|
-
state.first_token_time - state.created_time
|
1560
|
+
labels, state.first_token_time - state.created_time
|
1871
1561
|
)
|
1872
1562
|
else:
|
1873
1563
|
num_new_tokens = completion_tokens - state.last_completion_tokens
|
@@ -1875,6 +1565,7 @@ class TokenizerManager:
|
|
1875
1565
|
new_time = time.time()
|
1876
1566
|
interval = new_time - state.last_time
|
1877
1567
|
self.metrics_collector.observe_inter_token_latency(
|
1568
|
+
labels,
|
1878
1569
|
interval,
|
1879
1570
|
num_new_tokens,
|
1880
1571
|
)
|
@@ -1889,6 +1580,7 @@ class TokenizerManager:
|
|
1889
1580
|
or state.obj.sampling_params.get("structural_tag", None)
|
1890
1581
|
)
|
1891
1582
|
self.metrics_collector.observe_one_finished_request(
|
1583
|
+
labels,
|
1892
1584
|
recv_obj.prompt_tokens[i],
|
1893
1585
|
completion_tokens,
|
1894
1586
|
recv_obj.cached_tokens[i],
|
@@ -1941,7 +1633,7 @@ class TokenizerManager:
|
|
1941
1633
|
|
1942
1634
|
asyncio.create_task(asyncio.to_thread(background_task))
|
1943
1635
|
|
1944
|
-
def _handle_abort_req(self, recv_obj):
|
1636
|
+
def _handle_abort_req(self, recv_obj: AbortReq):
|
1945
1637
|
if is_health_check_generate_req(recv_obj):
|
1946
1638
|
return
|
1947
1639
|
state = self.rid_to_state[recv_obj.rid]
|
@@ -2060,11 +1752,15 @@ class TokenizerManager:
|
|
2060
1752
|
# the next position after the last token in the prompt
|
2061
1753
|
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
|
2062
1754
|
|
2063
|
-
#
|
2064
|
-
if
|
1755
|
+
# Check if output_logprobs is properly populated
|
1756
|
+
if (
|
1757
|
+
output_logprobs is None
|
1758
|
+
or not output_logprobs
|
1759
|
+
or len(output_logprobs) == 0
|
1760
|
+
):
|
2065
1761
|
raise RuntimeError(
|
2066
|
-
f"output_logprobs is
|
2067
|
-
"This
|
1762
|
+
f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
|
1763
|
+
"This indicates token_ids_logprobs were not computed properly for the scoring request."
|
2068
1764
|
)
|
2069
1765
|
|
2070
1766
|
for logprob, token_id, _ in output_logprobs[0]:
|
@@ -2089,6 +1785,43 @@ class TokenizerManager:
|
|
2089
1785
|
|
2090
1786
|
return scores
|
2091
1787
|
|
1788
|
+
async def watch_load_thread(self):
|
1789
|
+
# Only for dp_controller when dp_size > 1
|
1790
|
+
if (
|
1791
|
+
self.server_args.dp_size == 1
|
1792
|
+
or self.server_args.load_balance_method == "round_robin"
|
1793
|
+
):
|
1794
|
+
return
|
1795
|
+
|
1796
|
+
while True:
|
1797
|
+
await asyncio.sleep(self.server_args.load_watch_interval)
|
1798
|
+
loads = await self.get_load_communicator(GetLoadReqInput())
|
1799
|
+
load_udpate_req = WatchLoadUpdateReq(loads=loads)
|
1800
|
+
self.send_to_scheduler.send_pyobj(load_udpate_req)
|
1801
|
+
|
1802
|
+
def _trace_request_start(
|
1803
|
+
self,
|
1804
|
+
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
1805
|
+
created_time: Optional[float] = None,
|
1806
|
+
):
|
1807
|
+
if obj.is_single:
|
1808
|
+
bootstrap_room = (
|
1809
|
+
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
|
1810
|
+
)
|
1811
|
+
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
|
1812
|
+
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
|
1813
|
+
else:
|
1814
|
+
for i in range(len(obj.rid)):
|
1815
|
+
bootstrap_room = (
|
1816
|
+
obj.bootstrap_room[i]
|
1817
|
+
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
|
1818
|
+
else None
|
1819
|
+
)
|
1820
|
+
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
|
1821
|
+
trace_slice_start(
|
1822
|
+
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
|
1823
|
+
)
|
1824
|
+
|
2092
1825
|
|
2093
1826
|
class ServerStatus(Enum):
|
2094
1827
|
Up = "Up"
|
@@ -2134,57 +1867,12 @@ class SignalHandler:
|
|
2134
1867
|
|
2135
1868
|
def running_phase_sigquit_handler(self, signum=None, frame=None):
|
2136
1869
|
logger.error(
|
2137
|
-
"
|
1870
|
+
f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
|
2138
1871
|
)
|
2139
1872
|
self.tokenizer_manager.dump_requests_before_crash()
|
2140
1873
|
kill_process_tree(os.getpid())
|
2141
1874
|
|
2142
1875
|
|
2143
|
-
T = TypeVar("T")
|
2144
|
-
|
2145
|
-
|
2146
|
-
class _Communicator(Generic[T]):
|
2147
|
-
"""Note: The communicator now only run up to 1 in-flight request at any time."""
|
2148
|
-
|
2149
|
-
enable_multi_tokenizer = False
|
2150
|
-
|
2151
|
-
def __init__(self, sender, fan_out: int):
|
2152
|
-
self._sender = sender
|
2153
|
-
self._fan_out = fan_out
|
2154
|
-
self._result_event: Optional[asyncio.Event] = None
|
2155
|
-
self._result_values: Optional[List[T]] = None
|
2156
|
-
self._ready_queue: Deque[asyncio.Future] = deque()
|
2157
|
-
|
2158
|
-
async def __call__(self, obj):
|
2159
|
-
ready_event = asyncio.Event()
|
2160
|
-
if self._result_event is not None or len(self._ready_queue) > 0:
|
2161
|
-
self._ready_queue.append(ready_event)
|
2162
|
-
await ready_event.wait()
|
2163
|
-
assert self._result_event is None
|
2164
|
-
assert self._result_values is None
|
2165
|
-
|
2166
|
-
if obj:
|
2167
|
-
if _Communicator.enable_multi_tokenizer:
|
2168
|
-
obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
|
2169
|
-
self._sender.send_pyobj(obj)
|
2170
|
-
|
2171
|
-
self._result_event = asyncio.Event()
|
2172
|
-
self._result_values = []
|
2173
|
-
await self._result_event.wait()
|
2174
|
-
result_values = self._result_values
|
2175
|
-
self._result_event = self._result_values = None
|
2176
|
-
|
2177
|
-
if len(self._ready_queue) > 0:
|
2178
|
-
self._ready_queue.popleft().set()
|
2179
|
-
|
2180
|
-
return result_values
|
2181
|
-
|
2182
|
-
def handle_recv(self, recv_obj: T):
|
2183
|
-
self._result_values.append(recv_obj)
|
2184
|
-
if len(self._result_values) == self._fan_out:
|
2185
|
-
self._result_event.set()
|
2186
|
-
|
2187
|
-
|
2188
1876
|
# Note: request abort handling logic
|
2189
1877
|
# We should handle all of the following cases correctly.
|
2190
1878
|
#
|