sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -14,15 +14,17 @@ limitations under the License.
|
|
|
14
14
|
"""
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
-
import math
|
|
18
17
|
import threading
|
|
19
18
|
import time
|
|
20
|
-
from queue import Empty, Full,
|
|
21
|
-
from typing import TYPE_CHECKING, List, NamedTuple, Optional
|
|
19
|
+
from queue import Empty, Full, Queue
|
|
20
|
+
from typing import TYPE_CHECKING, List, NamedTuple, Optional
|
|
22
21
|
|
|
23
22
|
import torch
|
|
24
23
|
|
|
25
|
-
from sglang.srt.mem_cache.hicache_storage import
|
|
24
|
+
from sglang.srt.mem_cache.hicache_storage import (
|
|
25
|
+
HiCacheStorageConfig,
|
|
26
|
+
HiCacheStorageExtraInfo,
|
|
27
|
+
)
|
|
26
28
|
|
|
27
29
|
if TYPE_CHECKING:
|
|
28
30
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
|
@@ -38,7 +40,7 @@ from sglang.srt.layers.dp_attention import (
|
|
|
38
40
|
get_attention_tp_size,
|
|
39
41
|
is_dp_attention_enabled,
|
|
40
42
|
)
|
|
41
|
-
from sglang.srt.mem_cache.memory_pool import
|
|
43
|
+
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
|
42
44
|
|
|
43
45
|
logger = logging.getLogger(__name__)
|
|
44
46
|
|
|
@@ -191,12 +193,14 @@ class StorageOperation:
|
|
|
191
193
|
token_ids: List[int],
|
|
192
194
|
last_hash: Optional[str] = None,
|
|
193
195
|
hash_value: Optional[List[str]] = None,
|
|
196
|
+
prefix_keys: Optional[List[str]] = None,
|
|
194
197
|
):
|
|
195
198
|
self.host_indices = host_indices
|
|
196
199
|
self.token_ids = token_ids
|
|
197
200
|
self.last_hash = last_hash
|
|
198
201
|
self.completed_tokens = 0
|
|
199
202
|
self.hash_value = hash_value if hash_value is not None else []
|
|
203
|
+
self.prefix_keys = prefix_keys
|
|
200
204
|
|
|
201
205
|
self.id = StorageOperation.counter
|
|
202
206
|
StorageOperation.counter += 1
|
|
@@ -212,6 +216,7 @@ class PrefetchOperation(StorageOperation):
|
|
|
212
216
|
host_indices: torch.Tensor,
|
|
213
217
|
token_ids: List[int],
|
|
214
218
|
last_hash: Optional[str] = None,
|
|
219
|
+
prefix_keys: Optional[List[str]] = None,
|
|
215
220
|
):
|
|
216
221
|
self.request_id = request_id
|
|
217
222
|
|
|
@@ -219,7 +224,7 @@ class PrefetchOperation(StorageOperation):
|
|
|
219
224
|
self._terminated_flag = False
|
|
220
225
|
self.start_time = time.monotonic()
|
|
221
226
|
|
|
222
|
-
super().__init__(host_indices, token_ids, last_hash)
|
|
227
|
+
super().__init__(host_indices, token_ids, last_hash, prefix_keys=prefix_keys)
|
|
223
228
|
|
|
224
229
|
def increment(self, num_tokens: int):
|
|
225
230
|
with self._lock:
|
|
@@ -550,12 +555,13 @@ class HiCacheController:
|
|
|
550
555
|
host_indices: torch.Tensor,
|
|
551
556
|
new_input_tokens: List[int],
|
|
552
557
|
last_hash: Optional[str] = None,
|
|
558
|
+
prefix_keys: Optional[List[str]] = None,
|
|
553
559
|
) -> PrefetchOperation:
|
|
554
560
|
"""
|
|
555
561
|
Prefetch KV caches from storage backend to host memory.
|
|
556
562
|
"""
|
|
557
563
|
operation = PrefetchOperation(
|
|
558
|
-
request_id, host_indices, new_input_tokens, last_hash
|
|
564
|
+
request_id, host_indices, new_input_tokens, last_hash, prefix_keys
|
|
559
565
|
)
|
|
560
566
|
self.prefetch_queue.put(operation)
|
|
561
567
|
return operation
|
|
@@ -571,8 +577,12 @@ class HiCacheController:
|
|
|
571
577
|
for page in pages:
|
|
572
578
|
self.host_mem_release_queue.put(page)
|
|
573
579
|
|
|
574
|
-
def _page_get_zero_copy(
|
|
575
|
-
|
|
580
|
+
def _page_get_zero_copy(
|
|
581
|
+
self, operation, hash_values, host_indices, extra_info=None
|
|
582
|
+
):
|
|
583
|
+
results = self.storage_backend.batch_get_v1(
|
|
584
|
+
hash_values, host_indices, extra_info
|
|
585
|
+
)
|
|
576
586
|
inc = 0
|
|
577
587
|
for i in range(len(hash_values)):
|
|
578
588
|
if not results[i]:
|
|
@@ -584,7 +594,7 @@ class HiCacheController:
|
|
|
584
594
|
operation.increment(inc)
|
|
585
595
|
|
|
586
596
|
# todo: deprecate
|
|
587
|
-
def _generic_page_get(self, operation, hash_values, host_indices):
|
|
597
|
+
def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None):
|
|
588
598
|
dummy_page_dst = [
|
|
589
599
|
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
|
|
590
600
|
]
|
|
@@ -608,6 +618,7 @@ class HiCacheController:
|
|
|
608
618
|
|
|
609
619
|
def _page_transfer(self, operation):
|
|
610
620
|
# Transfer batch by batch
|
|
621
|
+
prefix_keys = operation.prefix_keys
|
|
611
622
|
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
|
612
623
|
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
|
613
624
|
batch_host_indices = operation.host_indices[
|
|
@@ -615,7 +626,8 @@ class HiCacheController:
|
|
|
615
626
|
]
|
|
616
627
|
prev_completed_tokens = operation.completed_tokens
|
|
617
628
|
# Get one batch token, and update the completed_tokens if succeed
|
|
618
|
-
|
|
629
|
+
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
|
|
630
|
+
self.page_get_func(operation, batch_hashes, batch_host_indices, extra_info)
|
|
619
631
|
# Check termination
|
|
620
632
|
if (
|
|
621
633
|
operation.completed_tokens
|
|
@@ -623,6 +635,10 @@ class HiCacheController:
|
|
|
623
635
|
):
|
|
624
636
|
operation.mark_terminate()
|
|
625
637
|
break # Some operations fail or operation terminated by controller
|
|
638
|
+
|
|
639
|
+
if prefix_keys and len(prefix_keys) > 0:
|
|
640
|
+
prefix_keys += batch_hashes
|
|
641
|
+
|
|
626
642
|
# release pre-allocated memory
|
|
627
643
|
self.append_host_mem_release(
|
|
628
644
|
operation.host_indices[operation.completed_tokens :]
|
|
@@ -656,6 +672,7 @@ class HiCacheController:
|
|
|
656
672
|
def _storage_hit_query(self, operation) -> tuple[list[str], int]:
|
|
657
673
|
last_hash = operation.last_hash
|
|
658
674
|
tokens_to_fetch = operation.token_ids
|
|
675
|
+
prefix_keys = operation.prefix_keys.copy() if operation.prefix_keys else None
|
|
659
676
|
|
|
660
677
|
storage_query_count = 0
|
|
661
678
|
hash_value = []
|
|
@@ -673,11 +690,15 @@ class HiCacheController:
|
|
|
673
690
|
batch_tokens[i : i + self.page_size], last_hash
|
|
674
691
|
)
|
|
675
692
|
batch_hashes.append(last_hash)
|
|
676
|
-
|
|
693
|
+
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
|
|
694
|
+
hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info)
|
|
677
695
|
hash_value.extend(batch_hashes[:hit_page_num])
|
|
678
696
|
storage_query_count += hit_page_num * self.page_size
|
|
679
697
|
if hit_page_num < len(batch_hashes):
|
|
680
698
|
break
|
|
699
|
+
if prefix_keys and len(prefix_keys) > 0:
|
|
700
|
+
prefix_keys += batch_hashes
|
|
701
|
+
|
|
681
702
|
return hash_value, storage_query_count
|
|
682
703
|
|
|
683
704
|
def prefetch_thread_func(self):
|
|
@@ -734,28 +755,34 @@ class HiCacheController:
|
|
|
734
755
|
host_indices: torch.Tensor,
|
|
735
756
|
token_ids: List[int],
|
|
736
757
|
hash_value: Optional[List[str]] = None,
|
|
758
|
+
prefix_keys: Optional[List[str]] = None,
|
|
737
759
|
) -> int:
|
|
738
760
|
"""
|
|
739
761
|
Write KV caches from host memory to storage backend.
|
|
740
762
|
"""
|
|
741
|
-
operation = StorageOperation(
|
|
763
|
+
operation = StorageOperation(
|
|
764
|
+
host_indices, token_ids, hash_value=hash_value, prefix_keys=prefix_keys
|
|
765
|
+
)
|
|
742
766
|
self.backup_queue.put(operation)
|
|
743
767
|
return operation.id
|
|
744
768
|
|
|
745
769
|
# todo: deprecate
|
|
746
|
-
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
|
770
|
+
def _generic_page_set(self, hash_values, host_indices, extra_info=None) -> bool:
|
|
747
771
|
data = [
|
|
748
772
|
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
|
|
749
773
|
for i in range(len(hash_values))
|
|
750
774
|
]
|
|
751
775
|
return self.storage_backend.batch_set(hash_values, data)
|
|
752
776
|
|
|
753
|
-
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
|
|
754
|
-
return all(
|
|
777
|
+
def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> bool:
|
|
778
|
+
return all(
|
|
779
|
+
self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info)
|
|
780
|
+
)
|
|
755
781
|
|
|
756
782
|
# Backup batch by batch
|
|
757
783
|
def _page_backup(self, operation):
|
|
758
784
|
# Backup batch by batch
|
|
785
|
+
prefix_keys = operation.prefix_keys
|
|
759
786
|
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
|
760
787
|
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
|
761
788
|
batch_host_indices = operation.host_indices[
|
|
@@ -763,12 +790,16 @@ class HiCacheController:
|
|
|
763
790
|
]
|
|
764
791
|
# Set one batch token, and record if success.
|
|
765
792
|
# todo: allow partial success
|
|
766
|
-
|
|
793
|
+
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
|
|
794
|
+
success = self.page_set_func(batch_hashes, batch_host_indices, extra_info)
|
|
767
795
|
if not success:
|
|
768
796
|
logger.warning(
|
|
769
797
|
f"Write page to storage: {len(batch_hashes)} pages failed."
|
|
770
798
|
)
|
|
771
799
|
break
|
|
800
|
+
|
|
801
|
+
if prefix_keys and len(prefix_keys) > 0:
|
|
802
|
+
prefix_keys += batch_hashes
|
|
772
803
|
operation.completed_tokens += self.page_size * len(batch_hashes)
|
|
773
804
|
|
|
774
805
|
def backup_thread_func(self):
|
|
@@ -21,7 +21,7 @@ import threading
|
|
|
21
21
|
import time
|
|
22
22
|
from collections import deque
|
|
23
23
|
from enum import Enum, auto
|
|
24
|
-
from typing import List
|
|
24
|
+
from typing import List, Optional
|
|
25
25
|
|
|
26
26
|
import psutil
|
|
27
27
|
import setproctitle
|
|
@@ -36,14 +36,19 @@ from sglang.srt.managers.io_struct import (
|
|
|
36
36
|
)
|
|
37
37
|
from sglang.srt.managers.schedule_batch import Req
|
|
38
38
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
|
39
|
-
from sglang.srt.server_args import
|
|
40
|
-
|
|
39
|
+
from sglang.srt.server_args import (
|
|
40
|
+
DP_ATTENTION_HANDSHAKE_PORT_DELTA,
|
|
41
|
+
PortArgs,
|
|
42
|
+
ServerArgs,
|
|
43
|
+
)
|
|
41
44
|
from sglang.srt.utils import (
|
|
42
45
|
bind_port,
|
|
43
46
|
configure_logger,
|
|
44
47
|
get_zmq_socket,
|
|
45
48
|
kill_itself_when_parent_died,
|
|
49
|
+
maybe_reindex_device_id,
|
|
46
50
|
)
|
|
51
|
+
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
47
52
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
|
48
53
|
|
|
49
54
|
logger = logging.getLogger(__name__)
|
|
@@ -135,27 +140,20 @@ class DataParallelController:
|
|
|
135
140
|
# Load balance budget
|
|
136
141
|
self.dp_budget = DPBudget()
|
|
137
142
|
|
|
143
|
+
# To protect changing env vars to set CUDA_VISIBLE_DEVICES.
|
|
144
|
+
self.env_lock = threading.Lock()
|
|
145
|
+
|
|
138
146
|
# Launch data parallel workers
|
|
139
147
|
self.scheduler_procs = []
|
|
140
148
|
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
|
|
141
149
|
|
|
142
150
|
if server_args.enable_dp_attention:
|
|
143
|
-
|
|
151
|
+
self.launch_dp_attention_schedulers(server_args, port_args)
|
|
144
152
|
self.control_message_step = server_args.tp_size
|
|
145
153
|
else:
|
|
146
|
-
|
|
154
|
+
self.launch_dp_schedulers(server_args, port_args)
|
|
147
155
|
self.control_message_step = 1
|
|
148
156
|
|
|
149
|
-
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
|
|
150
|
-
if server_args.node_rank == 0:
|
|
151
|
-
for dp_rank in range(server_args.dp_size):
|
|
152
|
-
self.workers[dp_rank] = get_zmq_socket(
|
|
153
|
-
self.context,
|
|
154
|
-
zmq.PUSH,
|
|
155
|
-
dp_port_args[dp_rank].scheduler_input_ipc_name,
|
|
156
|
-
True,
|
|
157
|
-
)
|
|
158
|
-
|
|
159
157
|
self.max_req_input_len = None
|
|
160
158
|
|
|
161
159
|
self.init_dispatcher()
|
|
@@ -188,13 +186,11 @@ class DataParallelController:
|
|
|
188
186
|
|
|
189
187
|
threads = []
|
|
190
188
|
sockets = []
|
|
191
|
-
dp_port_args = []
|
|
192
189
|
ready_events = []
|
|
193
190
|
for dp_rank in range(server_args.dp_size):
|
|
194
191
|
tmp_port_args = PortArgs.init_new(server_args)
|
|
195
192
|
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
|
196
193
|
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
|
|
197
|
-
dp_port_args.append(tmp_port_args)
|
|
198
194
|
|
|
199
195
|
# This port is checked free in PortArgs.init_new.
|
|
200
196
|
# We hold it first so that the next dp worker gets a different port
|
|
@@ -213,6 +209,14 @@ class DataParallelController:
|
|
|
213
209
|
server_args.tp_size * server_args.pp_size * server_args.gpu_id_step
|
|
214
210
|
)
|
|
215
211
|
|
|
212
|
+
if server_args.node_rank == 0:
|
|
213
|
+
self.workers[dp_rank] = get_zmq_socket(
|
|
214
|
+
self.context,
|
|
215
|
+
zmq.PUSH,
|
|
216
|
+
tmp_port_args.scheduler_input_ipc_name,
|
|
217
|
+
True,
|
|
218
|
+
)
|
|
219
|
+
|
|
216
220
|
# Free all sockets before starting the threads to launch TP workers
|
|
217
221
|
for sock in sockets:
|
|
218
222
|
sock.close()
|
|
@@ -223,8 +227,6 @@ class DataParallelController:
|
|
|
223
227
|
for event in ready_events:
|
|
224
228
|
event.wait()
|
|
225
229
|
|
|
226
|
-
return dp_port_args
|
|
227
|
-
|
|
228
230
|
def launch_tensor_parallel_group_thread(
|
|
229
231
|
self,
|
|
230
232
|
server_args: ServerArgs,
|
|
@@ -241,19 +243,115 @@ class DataParallelController:
|
|
|
241
243
|
while True:
|
|
242
244
|
time.sleep(30 * 24 * 3600)
|
|
243
245
|
|
|
244
|
-
def
|
|
245
|
-
self
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
246
|
+
def _broadcast_worker_ports(
|
|
247
|
+
self, server_args: ServerArgs, worker_ports: Optional[List[int]] = None
|
|
248
|
+
) -> List[int]:
|
|
249
|
+
"""Broadcast worker ports from node 0 to all other nodes.
|
|
250
|
+
|
|
251
|
+
Node 0 acts as the server, waiting for all other nodes to connect and
|
|
252
|
+
sending them the pre-allocated worker ports. Other nodes act as clients,
|
|
253
|
+
connecting to node 0 to receive their copy of the worker ports.
|
|
254
|
+
|
|
255
|
+
Args:
|
|
256
|
+
server_args: Server arguments containing node configuration.
|
|
257
|
+
worker_ports: Pre-allocated worker ports to broadcast.
|
|
258
|
+
|
|
259
|
+
Returns:
|
|
260
|
+
List of worker ports (same on all nodes after broadcast).
|
|
261
|
+
"""
|
|
262
|
+
# Determine the endpoint for inter-node communication
|
|
263
|
+
if server_args.dist_init_addr is None:
|
|
264
|
+
endpoint = f"tcp://127.0.0.1:{server_args.port + DP_ATTENTION_HANDSHAKE_PORT_DELTA}"
|
|
265
|
+
else:
|
|
266
|
+
endpoint = f"tcp://{server_args.dist_init_addr}"
|
|
267
|
+
|
|
268
|
+
if server_args.node_rank == 0:
|
|
269
|
+
# Node 0: Broadcast worker ports to all other nodes
|
|
270
|
+
return self._broadcast_ports_as_server(
|
|
271
|
+
endpoint, server_args.nnodes - 1, worker_ports
|
|
272
|
+
)
|
|
273
|
+
else:
|
|
274
|
+
# Other nodes: Receive worker ports from node 0
|
|
275
|
+
return self._receive_ports_as_client(endpoint, server_args.node_rank)
|
|
276
|
+
|
|
277
|
+
def _broadcast_ports_as_server(
|
|
278
|
+
self, endpoint: str, expected_clients: int, worker_ports: List[int]
|
|
279
|
+
) -> List[int]:
|
|
280
|
+
"""Broadcast worker ports to all client nodes."""
|
|
281
|
+
logger.debug(f"Broadcasting worker ports to {expected_clients} client nodes")
|
|
282
|
+
logger.debug(f"Worker ports: {worker_ports}")
|
|
283
|
+
|
|
284
|
+
rep_socket = get_zmq_socket(self.context, zmq.REP, endpoint, True)
|
|
285
|
+
|
|
286
|
+
try:
|
|
287
|
+
connected_clients = 0
|
|
288
|
+
while connected_clients < expected_clients:
|
|
289
|
+
# Wait for client handshake
|
|
290
|
+
client_rank = rep_socket.recv().decode()
|
|
291
|
+
logger.debug(f"Received handshake from node {client_rank}")
|
|
292
|
+
|
|
293
|
+
# Send worker ports to client
|
|
294
|
+
rep_socket.send_pyobj(worker_ports)
|
|
295
|
+
connected_clients += 1
|
|
296
|
+
logger.debug(
|
|
297
|
+
f"Sent worker ports to {connected_clients}/{expected_clients} nodes"
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
logger.debug("Worker port broadcast completed")
|
|
301
|
+
return worker_ports
|
|
302
|
+
finally:
|
|
303
|
+
rep_socket.close()
|
|
304
|
+
|
|
305
|
+
def _receive_ports_as_client(self, endpoint: str, node_rank: int) -> List[int]:
|
|
306
|
+
"""Receive worker ports from the server node."""
|
|
307
|
+
logger.debug(f"Connecting to node 0 to receive worker ports")
|
|
308
|
+
|
|
309
|
+
req_socket = get_zmq_socket(self.context, zmq.REQ, endpoint, False)
|
|
310
|
+
req_socket.setsockopt(zmq.RCVTIMEO, 60 * 1000) # 1 minute timeout
|
|
311
|
+
req_socket.setsockopt(zmq.SNDTIMEO, 60 * 1000)
|
|
312
|
+
|
|
313
|
+
try:
|
|
314
|
+
# Send handshake with our node rank
|
|
315
|
+
req_socket.send(str(node_rank).encode())
|
|
316
|
+
|
|
317
|
+
# Receive worker ports
|
|
318
|
+
worker_ports = req_socket.recv_pyobj()
|
|
319
|
+
logger.debug(f"Received {len(worker_ports)} worker ports from node 0")
|
|
320
|
+
return worker_ports
|
|
321
|
+
except zmq.Again:
|
|
322
|
+
logger.error("Timeout waiting for worker ports from node 0")
|
|
323
|
+
raise RuntimeError(
|
|
324
|
+
"Failed to receive worker ports from node 0 within timeout"
|
|
325
|
+
)
|
|
326
|
+
finally:
|
|
327
|
+
req_socket.close()
|
|
328
|
+
|
|
329
|
+
def launch_dp_attention_schedulers(
|
|
330
|
+
self, server_args: ServerArgs, port_args: PortArgs
|
|
331
|
+
):
|
|
332
|
+
# Pre-allocate worker ports on node 0 to avoid conflicts
|
|
333
|
+
worker_ports = []
|
|
334
|
+
if server_args.node_rank == 0:
|
|
335
|
+
for dp_rank in range(server_args.dp_size):
|
|
336
|
+
port_and_socket = get_zmq_socket(self.context, zmq.PUSH)
|
|
337
|
+
worker_ports.append(port_and_socket[0])
|
|
338
|
+
self.workers[dp_rank] = port_and_socket[1]
|
|
339
|
+
logger.debug(f"Assigned port {port_and_socket[0]} to worker {dp_rank}")
|
|
340
|
+
|
|
341
|
+
broadcasted_ports = self._broadcast_worker_ports(
|
|
342
|
+
server_args, worker_ports if worker_ports else None
|
|
343
|
+
)
|
|
344
|
+
self.launch_tensor_parallel_group(
|
|
345
|
+
server_args, port_args, 0, None, broadcasted_ports
|
|
346
|
+
)
|
|
250
347
|
|
|
251
348
|
def launch_tensor_parallel_group(
|
|
252
349
|
self,
|
|
253
350
|
server_args: ServerArgs,
|
|
254
351
|
port_args: PortArgs,
|
|
255
352
|
base_gpu_id: int,
|
|
256
|
-
dp_rank: int,
|
|
353
|
+
dp_rank: Optional[int],
|
|
354
|
+
worker_ports: Optional[List[int]] = None,
|
|
257
355
|
):
|
|
258
356
|
if not server_args.enable_dp_attention:
|
|
259
357
|
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
|
|
@@ -290,7 +388,9 @@ class DataParallelController:
|
|
|
290
388
|
server_args.dp_size,
|
|
291
389
|
)
|
|
292
390
|
# compute zmq ports for this dp rank
|
|
293
|
-
rank_port_args = PortArgs.init_new(
|
|
391
|
+
rank_port_args = PortArgs.init_new(
|
|
392
|
+
server_args, dp_rank, worker_ports
|
|
393
|
+
)
|
|
294
394
|
# Data parallelism reuses the tensor parallelism group,
|
|
295
395
|
# so all dp ranks should use the same nccl port.
|
|
296
396
|
rank_port_args.nccl_port = port_args.nccl_port
|
|
@@ -303,21 +403,22 @@ class DataParallelController:
|
|
|
303
403
|
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
|
|
304
404
|
)
|
|
305
405
|
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
406
|
+
with self.env_lock, maybe_reindex_device_id(gpu_id) as gpu_id:
|
|
407
|
+
proc = mp.Process(
|
|
408
|
+
target=run_scheduler_process,
|
|
409
|
+
args=(
|
|
410
|
+
server_args,
|
|
411
|
+
rank_port_args,
|
|
412
|
+
gpu_id,
|
|
413
|
+
tp_rank,
|
|
414
|
+
moe_ep_rank,
|
|
415
|
+
pp_rank,
|
|
416
|
+
dp_rank,
|
|
417
|
+
writer,
|
|
418
|
+
),
|
|
419
|
+
)
|
|
420
|
+
with memory_saver_adapter.configure_subprocess():
|
|
421
|
+
proc.start()
|
|
321
422
|
self.scheduler_procs.append(proc)
|
|
322
423
|
scheduler_pipe_readers.append(reader)
|
|
323
424
|
|
|
@@ -346,6 +447,9 @@ class DataParallelController:
|
|
|
346
447
|
self.workers
|
|
347
448
|
)
|
|
348
449
|
else:
|
|
450
|
+
assert (
|
|
451
|
+
req.bootstrap_room is not None
|
|
452
|
+
), "req.bootstrap_room should not be None. Do not send requests directly to prefill or decode instances, but send to the router instead."
|
|
349
453
|
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
|
350
454
|
|
|
351
455
|
def shortest_queue_scheduler(self, req):
|
|
@@ -31,7 +31,6 @@ from sglang.srt.managers.io_struct import (
|
|
|
31
31
|
BatchStrOutput,
|
|
32
32
|
BatchTokenIDOutput,
|
|
33
33
|
FreezeGCReq,
|
|
34
|
-
MultiTokenizerRegisterReq,
|
|
35
34
|
)
|
|
36
35
|
from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
|
|
37
36
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
@@ -104,12 +103,12 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
|
104
103
|
(BatchEmbeddingOutput, self.handle_batch_embedding_out),
|
|
105
104
|
(BatchTokenIDOutput, self.handle_batch_token_id_out),
|
|
106
105
|
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
|
107
|
-
(MultiTokenizerRegisterReq, lambda x: x),
|
|
108
106
|
(FreezeGCReq, self.handle_freeze_gc_req),
|
|
109
107
|
]
|
|
110
108
|
)
|
|
111
109
|
|
|
112
110
|
self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss"
|
|
111
|
+
self.disable_tokenizer_batch_decode = server_args.disable_tokenizer_batch_decode
|
|
113
112
|
|
|
114
113
|
def event_loop(self):
|
|
115
114
|
"""The event loop that handles requests"""
|
|
@@ -142,6 +141,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
|
142
141
|
if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
|
|
143
142
|
return output
|
|
144
143
|
assert len(output) > 0
|
|
144
|
+
# NOTE: We can always assume the last token is the matched stop token
|
|
145
145
|
return output[:-1]
|
|
146
146
|
return output
|
|
147
147
|
|
|
@@ -177,17 +177,39 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
|
177
177
|
)
|
|
178
178
|
surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
|
|
179
179
|
|
|
180
|
-
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
180
|
+
# TODO(lmzheng): better handle skip_special_tokens/spaces_between_special_tokens per request
|
|
181
|
+
if self.disable_tokenizer_batch_decode:
|
|
182
|
+
surr_texts = [
|
|
183
|
+
self.tokenizer.decode(
|
|
184
|
+
surr, skip_special_tokens=skip, spaces_between_special_tokens=space
|
|
185
|
+
)
|
|
186
|
+
for surr, skip, space in zip(
|
|
187
|
+
surr_ids,
|
|
188
|
+
recv_obj.skip_special_tokens,
|
|
189
|
+
recv_obj.spaces_between_special_tokens,
|
|
190
|
+
)
|
|
191
|
+
]
|
|
192
|
+
read_texts = [
|
|
193
|
+
self.tokenizer.decode(
|
|
194
|
+
read, skip_special_tokens=skip, spaces_between_special_tokens=space
|
|
195
|
+
)
|
|
196
|
+
for read, skip, space in zip(
|
|
197
|
+
read_ids,
|
|
198
|
+
recv_obj.skip_special_tokens,
|
|
199
|
+
recv_obj.spaces_between_special_tokens,
|
|
200
|
+
)
|
|
201
|
+
]
|
|
202
|
+
else:
|
|
203
|
+
surr_texts = self.tokenizer.batch_decode(
|
|
204
|
+
surr_ids,
|
|
205
|
+
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
|
206
|
+
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
|
|
207
|
+
)
|
|
208
|
+
read_texts = self.tokenizer.batch_decode(
|
|
209
|
+
read_ids,
|
|
210
|
+
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
|
211
|
+
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
|
|
212
|
+
)
|
|
191
213
|
|
|
192
214
|
# Incremental decoding
|
|
193
215
|
output_strs = []
|
|
@@ -226,6 +248,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
|
226
248
|
|
|
227
249
|
return BatchStrOutput(
|
|
228
250
|
rids=recv_obj.rids,
|
|
251
|
+
http_worker_ipcs=recv_obj.http_worker_ipcs,
|
|
229
252
|
finished_reasons=recv_obj.finished_reasons,
|
|
230
253
|
output_strs=output_strs,
|
|
231
254
|
output_ids=recv_obj.decode_ids,
|
|
@@ -233,6 +256,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
|
233
256
|
completion_tokens=recv_obj.completion_tokens,
|
|
234
257
|
cached_tokens=recv_obj.cached_tokens,
|
|
235
258
|
spec_verify_ct=recv_obj.spec_verify_ct,
|
|
259
|
+
spec_accepted_tokens=recv_obj.spec_accepted_tokens,
|
|
236
260
|
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
|
237
261
|
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
|
238
262
|
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
|
@@ -245,15 +269,18 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
|
245
269
|
input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
|
|
246
270
|
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
|
|
247
271
|
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
|
|
272
|
+
output_token_entropy_val=recv_obj.output_token_entropy_val,
|
|
248
273
|
output_hidden_states=recv_obj.output_hidden_states,
|
|
249
274
|
placeholder_tokens_idx=None,
|
|
250
275
|
placeholder_tokens_val=None,
|
|
276
|
+
token_steps=recv_obj.token_steps,
|
|
251
277
|
)
|
|
252
278
|
|
|
253
279
|
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
|
254
280
|
outputs = self.tokenizer.detokenize(recv_obj)
|
|
255
281
|
return BatchMultimodalOutput(
|
|
256
282
|
rids=recv_obj.rids,
|
|
283
|
+
http_worker_ipcs=recv_obj.http_worker_ipcs,
|
|
257
284
|
finished_reasons=recv_obj.finished_reasons,
|
|
258
285
|
outputs=outputs,
|
|
259
286
|
prompt_tokens=recv_obj.prompt_tokens,
|