sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,855 @@
|
|
1
|
+
"""
|
2
|
+
gRPC Request Manager - Orchestrates request lifecycle without tokenization.
|
3
|
+
Mimics TokenizerManager's state management and ZMQ communication patterns.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import asyncio
|
7
|
+
import copy
|
8
|
+
import dataclasses
|
9
|
+
import logging
|
10
|
+
import os
|
11
|
+
import signal
|
12
|
+
import sys
|
13
|
+
import threading
|
14
|
+
import time
|
15
|
+
import uuid
|
16
|
+
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
17
|
+
|
18
|
+
import grpc
|
19
|
+
import zmq
|
20
|
+
import zmq.asyncio
|
21
|
+
|
22
|
+
from sglang.srt.managers.io_struct import (
|
23
|
+
AbortReq,
|
24
|
+
BatchEmbeddingOutput,
|
25
|
+
BatchTokenIDOutput,
|
26
|
+
HealthCheckOutput,
|
27
|
+
TokenizedEmbeddingReqInput,
|
28
|
+
TokenizedGenerateReqInput,
|
29
|
+
)
|
30
|
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
31
|
+
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
32
|
+
from sglang.utils import get_exception_traceback
|
33
|
+
|
34
|
+
logger = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
|
37
|
+
class GrpcSignalHandler:
    """Signal callbacks for the gRPC server process.

    The handler keeps a reference to the request manager so SIGTERM can ask
    it to shut down gracefully. On SIGQUIT it only logs and tears down the
    process tree; per the log message below, crash dumps are left to the
    scheduler process.
    """

    def __init__(self, grpc_manager):
        # Manager whose ``gracefully_exit`` flag is flipped on SIGTERM.
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Log the SIGTERM and mark the manager for graceful exit."""
        message = f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        logger.warning(message)
        self.grpc_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Log the scheduler failure, then kill this process and its children."""
        logger.error(
            "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
        )
        logger.info(
            "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
        )
        # No dump is written from here; just take down the whole tree rooted
        # at the current process (the parent itself included).
        own_pid = os.getpid()
        kill_process_tree(own_pid, include_parent=True)
|
60
|
+
|
61
|
+
|
62
|
+
@dataclasses.dataclass
class GrpcReqState:
    """Mutable bookkeeping record for one in-flight gRPC request."""

    # ---- identity -------------------------------------------------------
    request_id: str
    grpc_context: Optional[grpc.aio.ServicerContext]

    # ---- communication --------------------------------------------------
    # Responses for this request are pulled from ``out_queue`` by the
    # request handler.
    out_queue: asyncio.Queue
    finished: bool
    event: asyncio.Event
    # The tokenized request this state belongs to.
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # ---- metrics (noted upstream as mirroring TokenizerManager's ReqState)
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # ---- streaming ------------------------------------------------------
    stream_finished: bool = False
    # Set once input logprobs have been sent while streaming.
    input_logprobs_sent: bool = False

    # ---- accumulated tokens / logprobs (non-streaming) --------------------
    # Every list gets its own fresh instance per state object.
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)

    # ---- session --------------------------------------------------------
    session_id: Optional[str] = None
    is_session_request: bool = False
|
101
|
+
|
102
|
+
|
103
|
+
class GrpcRequestManager:
|
104
|
+
"""
|
105
|
+
Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
|
106
|
+
behaviors without tokenization.
|
107
|
+
"""
|
108
|
+
|
109
|
+
def __init__(
    self,
    server_args: ServerArgs,
    port_args: PortArgs,
    bootstrap_server=None,
):
    """Create the ZMQ sockets and the per-manager bookkeeping state.

    Args:
        server_args: Server configuration; stored as-is. Its
            ``disaggregation_mode`` is logged when a bootstrap server is
            supplied.
        port_args: Provides the IPC names the two sockets bind to.
        bootstrap_server: Optional bootstrap server created by the caller;
            it is only stored here, not started.
    """
    self.server_args = server_args
    self.port_args = port_args

    # --- ZMQ wiring (pattern noted upstream as matching TokenizerManager) ---
    zmq_context = zmq.asyncio.Context(2)
    self.context = zmq_context

    # PULL side: outputs arriving from the scheduler.
    self.recv_from_scheduler = get_zmq_socket(
        zmq_context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
    )
    # PUSH side: requests going to the scheduler.
    self.send_to_scheduler = get_zmq_socket(
        zmq_context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
    )

    # --- Request tracking and lifecycle flags ---
    self.rid_to_state: Dict[str, GrpcReqState] = {}
    self.asyncio_tasks: set = set()
    self.gracefully_exit = False
    self.no_create_loop = False
    self.event_loop = None

    # --- Pause / resume ---
    self.is_pause = False
    self.is_pause_cond = asyncio.Condition()

    # --- Metrics ---
    self.last_receive_tstamp = time.time()

    # --- Crash-dump bookkeeping ---
    self.crash_dump_request_list = []
    self.crash_dump_performed = False

    # --- Bootstrap server handed in by the caller ---
    self.bootstrap_server = bootstrap_server

    logger.info(
        f"GrpcRequestManager initialized with ZMQ IPC: "
        f"recv={port_args.detokenizer_ipc_name}, "
        f"send={port_args.scheduler_input_ipc_name}"
    )

    if not self.bootstrap_server:
        return

    logger.info(
        f"Bootstrap server initialized for disaggregation mode: "
        f"{server_args.disaggregation_mode}"
    )
|
163
|
+
|
164
|
+
async def generate_request(
    self,
    obj: TokenizedGenerateReqInput,
    request_id: Optional[str] = None,
    grpc_context: Optional[grpc.aio.ServicerContext] = None,
) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
    """Submit a generation request, fanning out when ``sampling_params.n > 1``.

    For ``n <= 1`` the request is forwarded unchanged to
    ``_handle_single_request``. For ``n > 1`` a two-phase scheme is used:

    1. A copy of the request with ``max_new_tokens=0`` and ``n=1`` is sent
       under the id ``"<base>-prefix"`` and fully consumed, so the shared
       prompt is prefilled once.
    2. ``n`` copies with ``n=1`` are sent under the ids ``"<base>-<i>"``
       and run concurrently.

    Args:
        obj: The tokenized request. It is shallow-copied per sub-request,
            with ``sampling_params`` copied separately so the caller's
            object is not modified on the ``n > 1`` path.
        request_id: Base request id; a ``grpc-<uuid>`` id is generated when
            omitted.
        grpc_context: Optional gRPC context forwarded to every sub-request.

    Yields:
        * ``n <= 1``: whatever ``_handle_single_request`` yields.
        * ``n > 1``, non-streaming: a single list holding all responses,
          grouped by sub-request in index order.
        * ``n > 1``, streaming: individual responses as they arrive; dict
          responses whose ``meta_info["id"]`` matches a sub-request get an
          ``"index"`` key for client-side ordering.
    """
    n = getattr(obj.sampling_params, "n", 1)

    if n <= 1:
        async for response in self._handle_single_request(
            obj, request_id, grpc_context
        ):
            yield response
        return

    # N>1 handling - two-phase approach
    logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")

    base_request_id = (
        request_id if request_id is not None else f"grpc-{uuid.uuid4().hex}"
    )

    # Phase 1: prefill-only request so the common prefix is cached.
    logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
    prefix_obj = copy.copy(obj)
    prefix_obj.sampling_params = copy.copy(obj.sampling_params)
    prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
    prefix_obj.sampling_params.n = 1  # Don't replicate prefix request

    async for _ in self._handle_single_request(
        prefix_obj, f"{base_request_id}-prefix", grpc_context
    ):
        # Only the side effect matters; the prefix response is discarded.
        pass

    logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")

    # Phase 2: build the n single-sample sub-requests.
    logger.debug(f"Phase 2: Generating {n} parallel requests")
    request_ids = [f"{base_request_id}-{i}" for i in range(n)]
    generators = []
    for gen_request_id in request_ids:
        gen_obj = copy.copy(obj)
        gen_obj.sampling_params = copy.copy(obj.sampling_params)
        gen_obj.sampling_params.n = 1  # Each request generates 1 response
        # NOTE: an async generator does nothing until first iterated, so
        # nothing is sent to the scheduler at this point.
        generators.append(
            self._handle_single_request(gen_obj, gen_request_id, grpc_context)
        )

    async def _cancel_and_reap(tasks):
        """Cancel unfinished tasks and wait for all of them to settle."""
        tasks = list(tasks)
        for task in tasks:
            if not task.done():
                task.cancel()
        if tasks:
            # return_exceptions=True also marks stored exceptions as
            # retrieved, avoiding "exception was never retrieved" noise.
            await asyncio.gather(*tasks, return_exceptions=True)

    is_stream = getattr(obj, "stream", False)

    if not is_stream:
        # Non-streaming: collect all responses and return them as one batch.
        logger.debug(f"Non-streaming mode: collecting {n} responses")

        async def _drain(generator):
            return [response async for response in generator]

        # Drain concurrently. Iterating the generators one after another
        # would submit sub-request i+1 only after sub-request i finished,
        # serialising what is meant to be parallel sampling.
        drain_tasks = [asyncio.create_task(_drain(g)) for g in generators]
        try:
            per_request = await asyncio.gather(*drain_tasks)
        finally:
            # On failure or cancellation, stop the siblings instead of
            # leaving them running in the background.
            await _cancel_and_reap(drain_tasks)

        # gather() preserves argument order, so the batch is still grouped
        # by sub-request index.
        yield [response for chunk in per_request for response in chunk]
        return

    # Streaming mode: multiplex responses with index for ordering.
    logger.debug(f"Streaming mode: multiplexing {n} streams")
    rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
    exhausted = object()  # sentinel: the sub-stream has ended

    async def _next_or_exhausted(generator):
        try:
            return await generator.__anext__()
        except StopAsyncIteration:
            return exhausted

    # One outstanding "next item" task per live sub-stream.
    task_map = {
        asyncio.create_task(_next_or_exhausted(g)): g for g in generators
    }

    try:
        while task_map:
            done, _ = await asyncio.wait(
                set(task_map), return_when=asyncio.FIRST_COMPLETED
            )

            for task in done:
                generator = task_map.pop(task)
                response = task.result()
                if response is exhausted:
                    # This generator is finished; do not re-arm it.
                    continue

                # Add index for client-side ordering
                if isinstance(response, dict) and "meta_info" in response:
                    response_rid = response["meta_info"].get("id", "")
                    if response_rid in rid_to_index:
                        response["index"] = rid_to_index[response_rid]

                yield response

                # Re-arm: request the next item from the same sub-stream.
                next_task = asyncio.create_task(_next_or_exhausted(generator))
                task_map[next_task] = generator
    finally:
        # Reached on normal completion (task_map empty), when the consumer
        # closes or cancels this generator, or when a sub-stream raises.
        # Cancelling delivers CancelledError inside the sub-generator, the
        # same way cancellation reaches it on the n <= 1 path.
        await _cancel_and_reap(task_map)
|
280
|
+
|
281
|
+
async def _handle_single_request(
    self,
    obj: TokenizedGenerateReqInput,
    request_id: Optional[str] = None,
    grpc_context: Optional[grpc.aio.ServicerContext] = None,
):
    """Run a single generation request and yield its responses.

    Registers a ``GrpcReqState`` under the request id, forwards ``obj`` to
    the scheduler, then relays what is placed on the state's output queue.

    Args:
        obj: Tokenized generation request; its ``rid`` is overwritten with
            the effective request id.
        request_id: Id to use. A ``grpc-<uuid4 hex>`` id is generated when
            it is None.
        grpc_context: Optional gRPC context that is polled for client
            cancellation.

    Yields:
        Streaming mode: every queued response, in arrival order.
        Non-streaming mode: one copy of the finished response whose
        ``token_ids`` is replaced by the tokens accumulated on the state.

    Raises:
        Exception: anything raised by ``_send_to_scheduler`` propagates to
            the caller.
    """
    # Generate request ID if not provided
    if request_id is None:
        request_id = f"grpc-{uuid.uuid4().hex}"

    obj.rid = request_id

    # Create and register request state
    # TODO: support log_request
    state = GrpcReqState(
        request_id=request_id,
        grpc_context=grpc_context,
        out_queue=asyncio.Queue(),
        finished=False,
        event=asyncio.Event(),
        obj=obj,
        created_time=time.time(),
    )

    # Track session if needed
    if hasattr(obj, "session_params") and obj.session_params:
        state.session_id = obj.session_params.session_id
        state.is_session_request = True

    self.rid_to_state[request_id] = state
    self.record_request_for_crash_dump(obj)

    try:
        # Send to scheduler - let exceptions bubble up to the caller
        await self._send_to_scheduler(obj)

        is_stream = getattr(obj, "stream", False)

        while True:
            # Client cancelled - notify scheduler and exit
            if grpc_context and grpc_context.cancelled():
                await self.abort_request(request_id)
                return

            try:
                response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
            except asyncio.TimeoutError:
                # The timeout only bounds how long we go without polling
                # for client cancellation. A quiet queue is not an error:
                # the scheduler may simply not have produced output yet,
                # so aborting here would kill healthy slow requests.
                if state.finished and state.out_queue.empty():
                    # Safety net: the request was already marked finished
                    # and nothing is left to relay.
                    return
                continue

            if is_stream:
                yield response

            if not isinstance(response, dict):
                continue

            if response.get("finished", False):
                # Non-streaming: yield final response with accumulated
                # tokens from state
                if not is_stream:
                    final_response = response.copy()
                    final_response["token_ids"] = state.output_ids
                    yield final_response
                break

            # abort_request() and shutdown() queue these markers; nothing
            # else will arrive for this request afterwards, so stop instead
            # of waiting and issuing a second abort.
            if response.get("abort") or response.get("shutdown"):
                break

    finally:
        # Always clean up request state when exiting
        self._cleanup_request_state(request_id)
|
351
|
+
|
352
|
+
def _cleanup_request_state(self, request_id: str):
|
353
|
+
"""Clean up local request state (does not notify scheduler)."""
|
354
|
+
if request_id in self.rid_to_state:
|
355
|
+
del self.rid_to_state[request_id]
|
356
|
+
|
357
|
+
async def embedding_request(
    self,
    obj: TokenizedEmbeddingReqInput,
    request_id: Optional[str] = None,
) -> asyncio.Future:
    """Submit an embedding request to the scheduler.

    Args:
        obj: Tokenized embedding request; its ``rid`` is overwritten with
            the effective request id.
        request_id: Id to use. A ``grpc-embed-<uuid4 hex>`` id is generated
            when it is None.

    Returns:
        A future that receives the first item placed on the request's
        output queue once the request's event is set. If sending to the
        scheduler fails, the returned future already carries that
        exception.
    """
    # Generate request ID if not provided
    if request_id is None:
        request_id = f"grpc-embed-{uuid.uuid4().hex}"

    obj.rid = request_id

    # Create and register request state
    state = GrpcReqState(
        request_id=request_id,
        grpc_context=None,
        out_queue=asyncio.Queue(),
        finished=False,
        event=asyncio.Event(),
        obj=obj,
        created_time=time.time(),
    )
    self.rid_to_state[request_id] = state

    # Create future for result
    future = asyncio.Future()

    # Send to scheduler
    try:
        await self._send_to_scheduler(obj)
    except Exception as e:
        self.rid_to_state.pop(request_id, None)
        future.set_exception(e)
        return future

    # Wait for result in background
    async def wait_for_result():
        try:
            # Wait for completion, then take the result from the queue
            await state.event.wait()
            result = await state.out_queue.get()
            # The caller may have cancelled the future in the meantime;
            # resolving a done future raises InvalidStateError.
            if not future.done():
                future.set_result(result)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
        finally:
            # Clean up
            self.rid_to_state.pop(request_id, None)

    asyncio.create_task(wait_for_result())
    return future
|
414
|
+
|
415
|
+
async def abort_request(self, request_id: str) -> bool:
    """Ask the scheduler to abort ``request_id`` and finish its local state.

    Returns:
        False when the id is not tracked or the abort could not be sent to
        the scheduler (the failure is logged); True otherwise.
    """
    if request_id not in self.rid_to_state:
        return False

    # Tell the scheduler first; local state is only touched once that worked.
    abort_message = AbortReq(rid=request_id)
    try:
        await self._send_to_scheduler(abort_message)
    except Exception as e:
        logger.error(f"Failed to send abort request: {e}")
        return False

    # The entry may have been removed while the send was awaited.
    tracked = self.rid_to_state.get(request_id)
    if not tracked:
        return True

    tracked.finished = True
    tracked.stream_finished = True
    tracked.event.set()

    # Let whoever is reading the output queue know the request was aborted.
    await tracked.out_queue.put({"error": "Request aborted", "abort": True})
    return True
|
439
|
+
|
440
|
+
async def pause_generation(self):
    """Set the pause flag while holding ``is_pause_cond``.

    ``handle_loop`` checks this flag after each receive and waits on the
    condition while it is set.
    """
    async with self.is_pause_cond:
        self.is_pause = True
        logger.info("Generation paused")
|
445
|
+
|
446
|
+
async def resume_generation(self):
    """Clear the pause flag and wake every waiter on ``is_pause_cond``."""
    async with self.is_pause_cond:
        self.is_pause = False
        # Waiters re-check ``is_pause`` after being notified.
        self.is_pause_cond.notify_all()
        logger.info("Generation resumed")
|
452
|
+
|
453
|
+
async def handle_loop(self):
    """Receive scheduler outputs and dispatch them until shutdown.

    Each received object is routed by type to the matching ``_handle_*``
    coroutine; unknown types are logged and dropped. The loop ends when
    ``gracefully_exit`` is set or a ZMQ error other than ``Again`` occurs.
    The original note says this mimics TokenizerManager's handle_loop.
    """
    # (output type, handler) pairs, checked in this order.
    dispatch = (
        (BatchTokenIDOutput, self._handle_batch_output),
        (BatchEmbeddingOutput, self._handle_embedding_output),
        (HealthCheckOutput, self._handle_health_check_output),
    )

    while not self.gracefully_exit:
        try:
            # Receive from scheduler
            recv_obj = await self.recv_from_scheduler.recv_pyobj()
            self.last_receive_tstamp = time.time()

            # Hold the received object back while generation is paused
            async with self.is_pause_cond:
                while self.is_pause:
                    await self.is_pause_cond.wait()

            for output_type, handler in dispatch:
                if isinstance(recv_obj, output_type):
                    await handler(recv_obj)
                    break
            else:
                logger.warning(f"Unknown output type: {type(recv_obj)}")

        except zmq.error.Again:
            # Receive timed out; the loop condition decides whether to exit.
            continue
        except zmq.error.ZMQError as e:
            # Socket closed or other ZMQ error - always leaves the loop,
            # quietly when we are already shutting down.
            if self.gracefully_exit:
                logger.debug(f"ZMQ recv interrupted during shutdown: {e}")
            else:
                logger.error(
                    f"ZMQ error in handle loop: {e}\n{get_exception_traceback()}"
                )
            break
        except Exception as e:
            # Log and keep going; the loop condition ends the loop if a
            # shutdown was requested meanwhile.
            logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
|
497
|
+
|
498
|
+
def _convert_logprob_style(
|
499
|
+
self,
|
500
|
+
state: GrpcReqState,
|
501
|
+
batch_out: BatchTokenIDOutput,
|
502
|
+
batch_index: int,
|
503
|
+
):
|
504
|
+
"""
|
505
|
+
Convert and accumulate logprobs from batch output to state.
|
506
|
+
Follows the same logic as tokenizer_manager.convert_logprob_style.
|
507
|
+
"""
|
508
|
+
# Early exit if no input logprobs at all
|
509
|
+
if batch_out.input_token_logprobs_val is None:
|
510
|
+
return
|
511
|
+
|
512
|
+
# Accumulate input token logprobs (only if list is non-empty)
|
513
|
+
if len(batch_out.input_token_logprobs_val) > 0:
|
514
|
+
state.input_token_logprobs_val.extend(
|
515
|
+
batch_out.input_token_logprobs_val[batch_index]
|
516
|
+
)
|
517
|
+
state.input_token_logprobs_idx.extend(
|
518
|
+
batch_out.input_token_logprobs_idx[batch_index]
|
519
|
+
)
|
520
|
+
|
521
|
+
# Always accumulate output token logprobs
|
522
|
+
state.output_token_logprobs_val.extend(
|
523
|
+
batch_out.output_token_logprobs_val[batch_index]
|
524
|
+
)
|
525
|
+
state.output_token_logprobs_idx.extend(
|
526
|
+
batch_out.output_token_logprobs_idx[batch_index]
|
527
|
+
)
|
528
|
+
|
529
|
+
# Handle top logprobs if requested
|
530
|
+
if state.obj.top_logprobs_num > 0:
|
531
|
+
# Accumulate input top logprobs (only if list is non-empty)
|
532
|
+
if len(batch_out.input_top_logprobs_val) > 0:
|
533
|
+
state.input_top_logprobs_val.extend(
|
534
|
+
batch_out.input_top_logprobs_val[batch_index]
|
535
|
+
)
|
536
|
+
state.input_top_logprobs_idx.extend(
|
537
|
+
batch_out.input_top_logprobs_idx[batch_index]
|
538
|
+
)
|
539
|
+
|
540
|
+
# Always accumulate output top logprobs
|
541
|
+
state.output_top_logprobs_val.extend(
|
542
|
+
batch_out.output_top_logprobs_val[batch_index]
|
543
|
+
)
|
544
|
+
state.output_top_logprobs_idx.extend(
|
545
|
+
batch_out.output_top_logprobs_idx[batch_index]
|
546
|
+
)
|
547
|
+
|
548
|
+
async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
|
549
|
+
"""Handle batch generation output from scheduler."""
|
550
|
+
# Process each request in the batch
|
551
|
+
for i, rid in enumerate(batch_out.rids):
|
552
|
+
if rid not in self.rid_to_state:
|
553
|
+
continue
|
554
|
+
|
555
|
+
state = self.rid_to_state[rid]
|
556
|
+
|
557
|
+
# Update metrics
|
558
|
+
now = time.time()
|
559
|
+
if state.first_token_time == 0.0:
|
560
|
+
state.first_token_time = now
|
561
|
+
state.last_time = now
|
562
|
+
|
563
|
+
# Extract output for this request
|
564
|
+
output_data = {
|
565
|
+
"request_id": rid,
|
566
|
+
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
|
567
|
+
"finished": batch_out.finished_reasons[i] is not None,
|
568
|
+
"meta_info": {
|
569
|
+
"prompt_tokens": (
|
570
|
+
batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
|
571
|
+
),
|
572
|
+
"completion_tokens": (
|
573
|
+
batch_out.completion_tokens[i]
|
574
|
+
if batch_out.completion_tokens
|
575
|
+
else 0
|
576
|
+
),
|
577
|
+
"cached_tokens": (
|
578
|
+
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
|
579
|
+
),
|
580
|
+
"finish_reason": (
|
581
|
+
batch_out.finished_reasons[i]
|
582
|
+
if batch_out.finished_reasons[i]
|
583
|
+
else None
|
584
|
+
),
|
585
|
+
},
|
586
|
+
}
|
587
|
+
|
588
|
+
# Accumulate logprobs (following tokenizer_manager pattern)
|
589
|
+
if state.obj.return_logprob:
|
590
|
+
self._convert_logprob_style(state, batch_out, i)
|
591
|
+
|
592
|
+
# Send input logprobs based if available
|
593
|
+
if (
|
594
|
+
state.obj.return_logprob
|
595
|
+
and state.obj.logprob_start_len >= 0
|
596
|
+
and state.input_token_logprobs_val
|
597
|
+
):
|
598
|
+
if state.obj.stream and not state.input_logprobs_sent:
|
599
|
+
# Streaming: send input logprobs once in first chunk that has them
|
600
|
+
output_data["input_logprobs"] = {
|
601
|
+
"token_logprobs_val": state.input_token_logprobs_val,
|
602
|
+
"token_logprobs_idx": state.input_token_logprobs_idx,
|
603
|
+
"top_logprobs_val": state.input_top_logprobs_val,
|
604
|
+
"top_logprobs_idx": state.input_top_logprobs_idx,
|
605
|
+
}
|
606
|
+
state.input_logprobs_sent = True
|
607
|
+
elif not state.obj.stream and output_data["finished"]:
|
608
|
+
# Non-streaming: send input logprobs in final chunk
|
609
|
+
output_data["input_logprobs"] = {
|
610
|
+
"token_logprobs_val": state.input_token_logprobs_val,
|
611
|
+
"token_logprobs_idx": state.input_token_logprobs_idx,
|
612
|
+
"top_logprobs_val": state.input_top_logprobs_val,
|
613
|
+
"top_logprobs_idx": state.input_top_logprobs_idx,
|
614
|
+
}
|
615
|
+
|
616
|
+
# Send output logprobs if available
|
617
|
+
if (
|
618
|
+
state.obj.return_logprob
|
619
|
+
and batch_out.output_token_logprobs_val
|
620
|
+
and i < len(batch_out.output_token_logprobs_val)
|
621
|
+
):
|
622
|
+
if state.obj.stream:
|
623
|
+
# For streaming: send incremental logprobs (only new tokens in this chunk)
|
624
|
+
# NOTE: this is different than TokenizerManager, which always accumulates
|
625
|
+
def get_part(attr_name):
|
626
|
+
source_list = getattr(batch_out, attr_name, None)
|
627
|
+
return (
|
628
|
+
source_list[i]
|
629
|
+
if source_list and i < len(source_list)
|
630
|
+
else []
|
631
|
+
)
|
632
|
+
|
633
|
+
output_data["output_logprobs"] = {
|
634
|
+
"token_logprobs_val": batch_out.output_token_logprobs_val[i],
|
635
|
+
"token_logprobs_idx": get_part("output_token_logprobs_idx"),
|
636
|
+
"top_logprobs_val": get_part("output_top_logprobs_val"),
|
637
|
+
"top_logprobs_idx": get_part("output_top_logprobs_idx"),
|
638
|
+
}
|
639
|
+
elif output_data["finished"]:
|
640
|
+
# Non-streaming: send cumulative output logprobs in final chunk
|
641
|
+
output_data["output_logprobs"] = {
|
642
|
+
"token_logprobs_val": state.output_token_logprobs_val,
|
643
|
+
"token_logprobs_idx": state.output_token_logprobs_idx,
|
644
|
+
"top_logprobs_val": state.output_top_logprobs_val,
|
645
|
+
"top_logprobs_idx": state.output_top_logprobs_idx,
|
646
|
+
}
|
647
|
+
|
648
|
+
# Update state for accumulation
|
649
|
+
if output_data["token_ids"]:
|
650
|
+
state.output_ids.extend(output_data["token_ids"])
|
651
|
+
|
652
|
+
await state.out_queue.put(output_data)
|
653
|
+
|
654
|
+
# Handle completion
|
655
|
+
if output_data["finished"]:
|
656
|
+
state.finished = True
|
657
|
+
state.finished_time = now
|
658
|
+
state.stream_finished = True
|
659
|
+
state.event.set()
|
660
|
+
|
661
|
+
# Remove from tracking after a delay
|
662
|
+
async def cleanup():
|
663
|
+
await asyncio.sleep(5.0)
|
664
|
+
if rid in self.rid_to_state:
|
665
|
+
del self.rid_to_state[rid]
|
666
|
+
|
667
|
+
asyncio.create_task(cleanup())
|
668
|
+
|
669
|
+
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
|
670
|
+
"""Handle batch embedding output from scheduler."""
|
671
|
+
for i, rid in enumerate(batch_out.rids):
|
672
|
+
if rid not in self.rid_to_state:
|
673
|
+
continue
|
674
|
+
|
675
|
+
state = self.rid_to_state[rid]
|
676
|
+
|
677
|
+
# Create result
|
678
|
+
result = {
|
679
|
+
"request_id": rid,
|
680
|
+
"embedding": batch_out.embeddings[i],
|
681
|
+
"prompt_tokens": (
|
682
|
+
batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
|
683
|
+
),
|
684
|
+
"finish_reason": (
|
685
|
+
batch_out.finish_reason[i] if batch_out.finish_reason else None
|
686
|
+
),
|
687
|
+
}
|
688
|
+
|
689
|
+
# Send result
|
690
|
+
await state.out_queue.put(result)
|
691
|
+
|
692
|
+
# Mark as finished
|
693
|
+
state.finished = True
|
694
|
+
state.finished_time = time.time()
|
695
|
+
state.event.set()
|
696
|
+
|
697
|
+
async def _handle_health_check_output(self, health_out: HealthCheckOutput):
|
698
|
+
"""Handle health check output from scheduler."""
|
699
|
+
rid = health_out.rid
|
700
|
+
|
701
|
+
if rid not in self.rid_to_state:
|
702
|
+
logger.warning(f"Health check output for unknown request: {rid}")
|
703
|
+
return
|
704
|
+
|
705
|
+
state = self.rid_to_state[rid]
|
706
|
+
|
707
|
+
# Create health check result
|
708
|
+
result = {
|
709
|
+
"request_id": rid,
|
710
|
+
"healthy": True, # If we got a response, scheduler is healthy
|
711
|
+
"output_text": (
|
712
|
+
health_out.output_str if hasattr(health_out, "output_str") else ""
|
713
|
+
),
|
714
|
+
"finish_reason": (
|
715
|
+
health_out.finish_reason
|
716
|
+
if hasattr(health_out, "finish_reason")
|
717
|
+
else "stop"
|
718
|
+
),
|
719
|
+
}
|
720
|
+
|
721
|
+
# Send result
|
722
|
+
await state.out_queue.put(result)
|
723
|
+
|
724
|
+
# Mark as finished
|
725
|
+
state.finished = True
|
726
|
+
state.finished_time = time.time()
|
727
|
+
state.event.set()
|
728
|
+
|
729
|
+
async def _send_to_scheduler(self, obj):
|
730
|
+
"""Send an object to the scheduler via ZMQ."""
|
731
|
+
try:
|
732
|
+
self.send_to_scheduler.send_pyobj(obj)
|
733
|
+
except Exception as e:
|
734
|
+
logger.error(f"Failed to send to scheduler: {e}")
|
735
|
+
raise
|
736
|
+
|
737
|
+
def record_request_for_crash_dump(self, obj):
    """Append a summary of ``obj`` to ``crash_dump_request_list``.

    The summary holds the current time, the object's ``rid`` ("unknown"
    when it has none) and its type name. Nothing is appended once the list
    already holds 100 entries.
    """
    if len(self.crash_dump_request_list) >= 100:
        return

    summary = {
        "time": time.time(),
        "request_id": getattr(obj, "rid", "unknown"),
        "type": type(obj).__name__,
    }
    self.crash_dump_request_list.append(summary)
|
747
|
+
|
748
|
+
async def shutdown(self):
    """Stop background tasks, fail pending requests and release ZMQ resources.

    Steps, in order: set ``gracefully_exit``; cancel and await every task
    in ``asyncio_tasks``; queue a shutdown error for every unfinished
    request and set its event; await the tasks once more; shut down the
    bootstrap server if one is set; close both sockets and terminate the
    ZMQ context.
    """
    logger.info("Shutting down GrpcRequestManager")
    self.gracefully_exit = True

    # Cancel all asyncio tasks FIRST - this will interrupt blocked recv() calls
    for background_task in list(self.asyncio_tasks):
        if not background_task.done():
            background_task.cancel()

    # Give tasks a moment to process cancellation
    if self.asyncio_tasks:
        await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

    # Cancel all pending requests
    for pending in list(self.rid_to_state.values()):
        if pending.finished:
            continue
        await pending.out_queue.put(
            {"error": "Server shutting down", "shutdown": True}
        )
        pending.finished = True
        pending.event.set()

    # Wait for tasks to complete
    if self.asyncio_tasks:
        await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

    # Shutdown bootstrap server if running; failures here are only logged
    # so the socket cleanup below still runs.
    if self.bootstrap_server:
        logger.info("Shutting down bootstrap server")
        try:
            if hasattr(self.bootstrap_server, "shutdown"):
                # The server's shutdown may be sync or async.
                if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
                    await self.bootstrap_server.shutdown()
                else:
                    self.bootstrap_server.shutdown()
        except Exception as e:
            logger.warning(f"Error shutting down bootstrap server: {e}")

    # Close ZMQ sockets
    self.recv_from_scheduler.close()
    self.send_to_scheduler.close()

    # Terminate the ZMQ context - this is critical for asyncio loop to exit cleanly
    self.context.term()

    logger.info("GrpcRequestManager shutdown complete")
|
795
|
+
|
796
|
+
def get_server_info(self) -> Dict[str, Any]:
    """Return a snapshot of manager status for health checks.

    Returns:
        A dict with the number of tracked requests, the pause flag, and the
        timestamp of the last message received from the scheduler.
    """
    return dict(
        active_requests=len(self.rid_to_state),
        paused=self.is_pause,
        last_receive_time=self.last_receive_tstamp,
    )
|
803
|
+
|
804
|
+
def auto_create_handle_loop(self):
    """Start the background tasks once; later calls are no-ops.

    On the first call this schedules ``handle_loop`` on the current event
    loop, remembers the loop in ``event_loop``, installs SIGTERM/SIGQUIT
    handlers when running in the main thread (otherwise logs a warning),
    and schedules ``sigterm_watchdog``. The original note says this matches
    the TokenizerManager pattern.
    """
    if self.no_create_loop:
        return
    self.no_create_loop = True

    loop = asyncio.get_event_loop()
    handle_loop_task = loop.create_task(print_exception_wrapper(self.handle_loop))
    self.asyncio_tasks.add(handle_loop_task)

    self.event_loop = loop

    # We cannot add signal handler when the grpc manager is not in
    # the main thread due to the CPython limitation.
    in_main_thread = threading.current_thread() is threading.main_thread()
    if not in_main_thread:
        logger.warning(
            "Signal handler is not added because the grpc request manager is "
            "not in the main thread. This disables graceful shutdown of the "
            "grpc request manager when SIGTERM is received."
        )
    else:
        handlers = GrpcSignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, handlers.sigterm_handler)
        # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
        loop.add_signal_handler(
            signal.SIGQUIT, handlers.running_phase_sigquit_handler
        )

    watchdog_task = loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
    self.asyncio_tasks.add(watchdog_task)
|
835
|
+
|
836
|
+
async def sigterm_watchdog(self):
    """Poll ``gracefully_exit`` once per second and return once it is set.

    The original note says this matches the TokenizerManager pattern.
    """
    while True:
        if self.gracefully_exit:
            return
        await asyncio.sleep(1.0)
|
840
|
+
|
841
|
+
|
842
|
+
async def print_exception_wrapper(func):
    """Await ``func()`` and make any exception fatal and visible.

    If ``func`` raises, the traceback is logged; when ``func`` is a bound
    method of a ``GrpcRequestManager`` its ``dump_requests_before_crash``
    is called; then the whole process tree is killed and the interpreter
    exits with status 1.
    """
    try:
        await func()
    except Exception:
        trace_text = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {trace_text}")

        # Bound methods expose their instance as ``__self__``.
        owner = getattr(func, "__self__", None)
        if isinstance(owner, GrpcRequestManager):
            owner.dump_requests_before_crash()

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
|