sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
 - sglang/bench_one_batch_server.py +340 -34
 - sglang/bench_serving.py +340 -159
 - sglang/check_env.py +1 -1
 - sglang/compile_deep_gemm.py +6 -2
 - sglang/global_config.py +1 -25
 - sglang/lang/api.py +6 -0
 - sglang/lang/backend/runtime_endpoint.py +1 -1
 - sglang/lang/interpreter.py +1 -0
 - sglang/lang/ir.py +13 -0
 - sglang/launch_server.py +9 -2
 - sglang/profiler.py +20 -3
 - sglang/srt/_custom_ops.py +1 -1
 - sglang/srt/batch_invariant_ops/__init__.py +27 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
 - sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
 - sglang/srt/compilation/backend.py +437 -0
 - sglang/srt/compilation/compilation_config.py +20 -0
 - sglang/srt/compilation/compilation_counter.py +47 -0
 - sglang/srt/compilation/compile.py +210 -0
 - sglang/srt/compilation/compiler_interface.py +503 -0
 - sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
 - sglang/srt/compilation/fix_functionalization.py +134 -0
 - sglang/srt/compilation/fx_utils.py +83 -0
 - sglang/srt/compilation/inductor_pass.py +140 -0
 - sglang/srt/compilation/pass_manager.py +66 -0
 - sglang/srt/compilation/piecewise_context_manager.py +40 -0
 - sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
 - sglang/srt/configs/__init__.py +8 -0
 - sglang/srt/configs/deepseek_ocr.py +262 -0
 - sglang/srt/configs/deepseekvl2.py +194 -96
 - sglang/srt/configs/dots_ocr.py +64 -0
 - sglang/srt/configs/dots_vlm.py +2 -7
 - sglang/srt/configs/falcon_h1.py +309 -0
 - sglang/srt/configs/load_config.py +33 -2
 - sglang/srt/configs/mamba_utils.py +117 -0
 - sglang/srt/configs/model_config.py +284 -118
 - sglang/srt/configs/modelopt_config.py +30 -0
 - sglang/srt/configs/nemotron_h.py +286 -0
 - sglang/srt/configs/olmo3.py +105 -0
 - sglang/srt/configs/points_v15_chat.py +29 -0
 - sglang/srt/configs/qwen3_next.py +11 -47
 - sglang/srt/configs/qwen3_omni.py +613 -0
 - sglang/srt/configs/qwen3_vl.py +576 -0
 - sglang/srt/connector/remote_instance.py +1 -1
 - sglang/srt/constrained/base_grammar_backend.py +6 -1
 - sglang/srt/constrained/llguidance_backend.py +5 -0
 - sglang/srt/constrained/outlines_backend.py +1 -1
 - sglang/srt/constrained/outlines_jump_forward.py +1 -1
 - sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
 - sglang/srt/constrained/utils.py +12 -0
 - sglang/srt/constrained/xgrammar_backend.py +26 -15
 - sglang/srt/debug_utils/dumper.py +10 -3
 - sglang/srt/disaggregation/ascend/conn.py +2 -2
 - sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
 - sglang/srt/disaggregation/base/conn.py +17 -4
 - sglang/srt/disaggregation/common/conn.py +268 -98
 - sglang/srt/disaggregation/decode.py +172 -39
 - sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
 - sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
 - sglang/srt/disaggregation/fake/conn.py +11 -3
 - sglang/srt/disaggregation/mooncake/conn.py +203 -555
 - sglang/srt/disaggregation/nixl/conn.py +217 -63
 - sglang/srt/disaggregation/prefill.py +113 -270
 - sglang/srt/disaggregation/utils.py +36 -5
 - sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
 - sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
 - sglang/srt/distributed/device_communicators/pynccl.py +24 -12
 - sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
 - sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
 - sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
 - sglang/srt/distributed/naive_distributed.py +5 -4
 - sglang/srt/distributed/parallel_state.py +203 -97
 - sglang/srt/elastic_ep/elastic_ep.py +74 -0
 - sglang/srt/entrypoints/context.py +3 -2
 - sglang/srt/entrypoints/engine.py +85 -65
 - sglang/srt/entrypoints/grpc_server.py +632 -305
 - sglang/srt/entrypoints/harmony_utils.py +2 -2
 - sglang/srt/entrypoints/http_server.py +169 -17
 - sglang/srt/entrypoints/http_server_engine.py +1 -7
 - sglang/srt/entrypoints/openai/protocol.py +327 -34
 - sglang/srt/entrypoints/openai/serving_base.py +74 -8
 - sglang/srt/entrypoints/openai/serving_chat.py +202 -118
 - sglang/srt/entrypoints/openai/serving_classify.py +204 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +20 -4
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/entrypoints/openai/serving_responses.py +47 -2
 - sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
 - sglang/srt/environ.py +323 -0
 - sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
 - sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
 - sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
 - sglang/srt/eplb/expert_distribution.py +3 -4
 - sglang/srt/eplb/expert_location.py +30 -5
 - sglang/srt/eplb/expert_location_dispatch.py +2 -2
 - sglang/srt/eplb/expert_location_updater.py +2 -2
 - sglang/srt/function_call/base_format_detector.py +17 -18
 - sglang/srt/function_call/function_call_parser.py +21 -16
 - sglang/srt/function_call/glm4_moe_detector.py +4 -8
 - sglang/srt/function_call/gpt_oss_detector.py +24 -1
 - sglang/srt/function_call/json_array_parser.py +61 -0
 - sglang/srt/function_call/kimik2_detector.py +17 -4
 - sglang/srt/function_call/utils.py +98 -7
 - sglang/srt/grpc/compile_proto.py +245 -0
 - sglang/srt/grpc/grpc_request_manager.py +915 -0
 - sglang/srt/grpc/health_servicer.py +189 -0
 - sglang/srt/grpc/scheduler_launcher.py +181 -0
 - sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
 - sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
 - sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
 - sglang/srt/layers/activation.py +11 -7
 - sglang/srt/layers/attention/aiter_backend.py +17 -18
 - sglang/srt/layers/attention/ascend_backend.py +125 -10
 - sglang/srt/layers/attention/attention_registry.py +226 -0
 - sglang/srt/layers/attention/base_attn_backend.py +32 -4
 - sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
 - sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
 - sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
 - sglang/srt/layers/attention/fla/chunk.py +0 -1
 - sglang/srt/layers/attention/fla/chunk_o.py +1 -1
 - sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
 - sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
 - sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
 - sglang/srt/layers/attention/fla/index.py +0 -2
 - sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
 - sglang/srt/layers/attention/fla/utils.py +0 -3
 - sglang/srt/layers/attention/fla/wy_fast.py +0 -2
 - sglang/srt/layers/attention/flashattention_backend.py +52 -15
 - sglang/srt/layers/attention/flashinfer_backend.py +357 -212
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
 - sglang/srt/layers/attention/flashmla_backend.py +9 -7
 - sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
 - sglang/srt/layers/attention/intel_amx_backend.py +1 -1
 - sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
 - sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
 - sglang/srt/layers/attention/mamba/mamba.py +514 -1
 - sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
 - sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
 - sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
 - sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
 - sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
 - sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
 - sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
 - sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
 - sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
 - sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
 - sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
 - sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
 - sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
 - sglang/srt/layers/attention/nsa/transform_index.py +144 -0
 - sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
 - sglang/srt/layers/attention/nsa/utils.py +23 -0
 - sglang/srt/layers/attention/nsa_backend.py +1201 -0
 - sglang/srt/layers/attention/tbo_backend.py +6 -6
 - sglang/srt/layers/attention/torch_flex_backend.py +325 -0
 - sglang/srt/layers/attention/triton_backend.py +249 -42
 - sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
 - sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
 - sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
 - sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
 - sglang/srt/layers/attention/utils.py +11 -7
 - sglang/srt/layers/attention/vision.py +61 -3
 - sglang/srt/layers/attention/wave_backend.py +4 -4
 - sglang/srt/layers/attention/xpu_backend.py +1028 -0
 - sglang/srt/layers/communicator.py +19 -7
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
 - sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
 - sglang/srt/layers/dp_attention.py +28 -1
 - sglang/srt/layers/elementwise.py +3 -1
 - sglang/srt/layers/layernorm.py +47 -15
 - sglang/srt/layers/linear.py +30 -5
 - sglang/srt/layers/logits_processor.py +161 -18
 - sglang/srt/layers/modelopt_utils.py +11 -0
 - sglang/srt/layers/moe/cutlass_moe.py +0 -2
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
 - sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
 - sglang/srt/layers/moe/ep_moe/layer.py +243 -448
 - sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton.py +3 -1
 - sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
 - sglang/srt/layers/moe/router.py +51 -15
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
 - sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
 - sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
 - sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
 - sglang/srt/layers/moe/topk.py +3 -2
 - sglang/srt/layers/moe/utils.py +27 -1
 - sglang/srt/layers/parameter.py +23 -6
 - sglang/srt/layers/quantization/__init__.py +2 -53
 - sglang/srt/layers/quantization/awq.py +183 -6
 - sglang/srt/layers/quantization/awq_triton.py +29 -0
 - sglang/srt/layers/quantization/base_config.py +20 -1
 - sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
 - sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
 - sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
 - sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
 - sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
 - sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
 - sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
 - sglang/srt/layers/quantization/fp8.py +86 -20
 - sglang/srt/layers/quantization/fp8_kernel.py +55 -10
 - sglang/srt/layers/quantization/fp8_utils.py +43 -15
 - sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
 - sglang/srt/layers/quantization/gptq.py +0 -1
 - sglang/srt/layers/quantization/int8_kernel.py +18 -2
 - sglang/srt/layers/quantization/marlin_utils.py +12 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +141 -81
 - sglang/srt/layers/quantization/mxfp4.py +17 -34
 - sglang/srt/layers/quantization/petit.py +1 -1
 - sglang/srt/layers/quantization/quark/quark.py +3 -1
 - sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
 - sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
 - sglang/srt/layers/quantization/unquant.py +1 -4
 - sglang/srt/layers/quantization/utils.py +0 -1
 - sglang/srt/layers/quantization/w4afp8.py +51 -24
 - sglang/srt/layers/quantization/w8a8_int8.py +45 -27
 - sglang/srt/layers/radix_attention.py +59 -9
 - sglang/srt/layers/rotary_embedding.py +750 -46
 - sglang/srt/layers/sampler.py +84 -16
 - sglang/srt/layers/sparse_pooler.py +98 -0
 - sglang/srt/layers/utils.py +23 -1
 - sglang/srt/layers/vocab_parallel_embedding.py +4 -1
 - sglang/srt/lora/backend/base_backend.py +3 -3
 - sglang/srt/lora/backend/chunked_backend.py +348 -0
 - sglang/srt/lora/backend/triton_backend.py +9 -4
 - sglang/srt/lora/eviction_policy.py +139 -0
 - sglang/srt/lora/lora.py +7 -5
 - sglang/srt/lora/lora_manager.py +33 -7
 - sglang/srt/lora/lora_registry.py +1 -1
 - sglang/srt/lora/mem_pool.py +41 -17
 - sglang/srt/lora/triton_ops/__init__.py +4 -0
 - sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
 - sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
 - sglang/srt/lora/utils.py +7 -5
 - sglang/srt/managers/cache_controller.py +83 -152
 - sglang/srt/managers/data_parallel_controller.py +156 -87
 - sglang/srt/managers/detokenizer_manager.py +51 -24
 - sglang/srt/managers/io_struct.py +223 -129
 - sglang/srt/managers/mm_utils.py +49 -10
 - sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
 - sglang/srt/managers/multimodal_processor.py +1 -2
 - sglang/srt/managers/overlap_utils.py +130 -0
 - sglang/srt/managers/schedule_batch.py +340 -529
 - sglang/srt/managers/schedule_policy.py +158 -18
 - sglang/srt/managers/scheduler.py +665 -620
 - sglang/srt/managers/scheduler_input_blocker.py +1 -1
 - sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
 - sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
 - sglang/srt/managers/scheduler_pp_mixin.py +341 -0
 - sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
 - sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
 - sglang/srt/managers/tokenizer_manager.py +462 -226
 - sglang/srt/managers/tp_worker.py +217 -156
 - sglang/srt/managers/utils.py +79 -47
 - sglang/srt/mem_cache/allocator.py +21 -22
 - sglang/srt/mem_cache/allocator_ascend.py +42 -28
 - sglang/srt/mem_cache/base_prefix_cache.py +3 -3
 - sglang/srt/mem_cache/chunk_cache.py +20 -2
 - sglang/srt/mem_cache/common.py +480 -0
 - sglang/srt/mem_cache/evict_policy.py +38 -0
 - sglang/srt/mem_cache/hicache_storage.py +44 -2
 - sglang/srt/mem_cache/hiradix_cache.py +134 -34
 - sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
 - sglang/srt/mem_cache/memory_pool.py +602 -208
 - sglang/srt/mem_cache/memory_pool_host.py +134 -183
 - sglang/srt/mem_cache/multimodal_cache.py +0 -1
 - sglang/srt/mem_cache/radix_cache.py +263 -78
 - sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
 - sglang/srt/mem_cache/storage/__init__.py +10 -0
 - sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
 - sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
 - sglang/srt/mem_cache/storage/backend_factory.py +223 -0
 - sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
 - sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
 - sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
 - sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
 - sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
 - sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
 - sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
 - sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
 - sglang/srt/mem_cache/swa_radix_cache.py +115 -58
 - sglang/srt/metrics/collector.py +113 -120
 - sglang/srt/metrics/func_timer.py +3 -8
 - sglang/srt/metrics/utils.py +8 -1
 - sglang/srt/model_executor/cpu_graph_runner.py +2 -2
 - sglang/srt/model_executor/cuda_graph_runner.py +81 -36
 - sglang/srt/model_executor/forward_batch_info.py +40 -50
 - sglang/srt/model_executor/model_runner.py +507 -319
 - sglang/srt/model_executor/npu_graph_runner.py +11 -5
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
 - sglang/srt/model_loader/__init__.py +1 -1
 - sglang/srt/model_loader/loader.py +438 -37
 - sglang/srt/model_loader/utils.py +0 -1
 - sglang/srt/model_loader/weight_utils.py +200 -27
 - sglang/srt/models/apertus.py +2 -3
 - sglang/srt/models/arcee.py +2 -2
 - sglang/srt/models/bailing_moe.py +40 -56
 - sglang/srt/models/bailing_moe_nextn.py +3 -4
 - sglang/srt/models/bert.py +1 -1
 - sglang/srt/models/deepseek_nextn.py +25 -4
 - sglang/srt/models/deepseek_ocr.py +1516 -0
 - sglang/srt/models/deepseek_v2.py +793 -235
 - sglang/srt/models/dots_ocr.py +171 -0
 - sglang/srt/models/dots_vlm.py +0 -1
 - sglang/srt/models/dots_vlm_vit.py +1 -1
 - sglang/srt/models/falcon_h1.py +570 -0
 - sglang/srt/models/gemma3_causal.py +0 -2
 - sglang/srt/models/gemma3_mm.py +17 -1
 - sglang/srt/models/gemma3n_mm.py +2 -3
 - sglang/srt/models/glm4_moe.py +17 -40
 - sglang/srt/models/glm4_moe_nextn.py +4 -4
 - sglang/srt/models/glm4v.py +3 -2
 - sglang/srt/models/glm4v_moe.py +6 -6
 - sglang/srt/models/gpt_oss.py +12 -35
 - sglang/srt/models/grok.py +10 -23
 - sglang/srt/models/hunyuan.py +2 -7
 - sglang/srt/models/interns1.py +0 -1
 - sglang/srt/models/kimi_vl.py +1 -7
 - sglang/srt/models/kimi_vl_moonvit.py +4 -2
 - sglang/srt/models/llama.py +6 -2
 - sglang/srt/models/llama_eagle3.py +1 -1
 - sglang/srt/models/longcat_flash.py +6 -23
 - sglang/srt/models/longcat_flash_nextn.py +4 -15
 - sglang/srt/models/mimo.py +2 -13
 - sglang/srt/models/mimo_mtp.py +1 -2
 - sglang/srt/models/minicpmo.py +7 -5
 - sglang/srt/models/mixtral.py +1 -4
 - sglang/srt/models/mllama.py +1 -1
 - sglang/srt/models/mllama4.py +27 -6
 - sglang/srt/models/nemotron_h.py +511 -0
 - sglang/srt/models/olmo2.py +31 -4
 - sglang/srt/models/opt.py +5 -5
 - sglang/srt/models/phi.py +1 -1
 - sglang/srt/models/phi4mm.py +1 -1
 - sglang/srt/models/phimoe.py +0 -1
 - sglang/srt/models/pixtral.py +0 -3
 - sglang/srt/models/points_v15_chat.py +186 -0
 - sglang/srt/models/qwen.py +0 -1
 - sglang/srt/models/qwen2.py +0 -7
 - sglang/srt/models/qwen2_5_vl.py +5 -5
 - sglang/srt/models/qwen2_audio.py +2 -15
 - sglang/srt/models/qwen2_moe.py +70 -4
 - sglang/srt/models/qwen2_vl.py +6 -3
 - sglang/srt/models/qwen3.py +18 -3
 - sglang/srt/models/qwen3_moe.py +50 -38
 - sglang/srt/models/qwen3_next.py +43 -21
 - sglang/srt/models/qwen3_next_mtp.py +3 -4
 - sglang/srt/models/qwen3_omni_moe.py +661 -0
 - sglang/srt/models/qwen3_vl.py +791 -0
 - sglang/srt/models/qwen3_vl_moe.py +343 -0
 - sglang/srt/models/registry.py +15 -3
 - sglang/srt/models/roberta.py +55 -3
 - sglang/srt/models/sarashina2_vision.py +268 -0
 - sglang/srt/models/solar.py +505 -0
 - sglang/srt/models/starcoder2.py +357 -0
 - sglang/srt/models/step3_vl.py +3 -5
 - sglang/srt/models/torch_native_llama.py +9 -2
 - sglang/srt/models/utils.py +61 -0
 - sglang/srt/multimodal/processors/base_processor.py +21 -9
 - sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
 - sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
 - sglang/srt/multimodal/processors/dots_vlm.py +2 -4
 - sglang/srt/multimodal/processors/glm4v.py +1 -5
 - sglang/srt/multimodal/processors/internvl.py +20 -10
 - sglang/srt/multimodal/processors/janus_pro.py +0 -1
 - sglang/srt/multimodal/processors/mllama4.py +0 -8
 - sglang/srt/multimodal/processors/phi4mm.py +0 -1
 - sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
 - sglang/srt/multimodal/processors/qwen_vl.py +83 -17
 - sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
 - sglang/srt/multimodal/processors/step3_vl.py +1 -1
 - sglang/srt/parser/conversation.py +41 -0
 - sglang/srt/parser/jinja_template_utils.py +6 -0
 - sglang/srt/parser/reasoning_parser.py +0 -1
 - sglang/srt/sampling/custom_logit_processor.py +77 -2
 - sglang/srt/sampling/sampling_batch_info.py +36 -23
 - sglang/srt/sampling/sampling_params.py +75 -0
 - sglang/srt/server_args.py +1300 -338
 - sglang/srt/server_args_config_parser.py +146 -0
 - sglang/srt/single_batch_overlap.py +161 -0
 - sglang/srt/speculative/base_spec_worker.py +34 -0
 - sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
 - sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
 - sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
 - sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
 - sglang/srt/speculative/cpp_ngram/param.h +125 -0
 - sglang/srt/speculative/cpp_ngram/queue.h +71 -0
 - sglang/srt/speculative/draft_utils.py +226 -0
 - sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
 - sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
 - sglang/srt/speculative/eagle_info.py +786 -0
 - sglang/srt/speculative/eagle_info_v2.py +458 -0
 - sglang/srt/speculative/eagle_utils.py +113 -1270
 - sglang/srt/speculative/eagle_worker.py +120 -285
 - sglang/srt/speculative/eagle_worker_v2.py +702 -0
 - sglang/srt/speculative/ngram_info.py +433 -0
 - sglang/srt/speculative/ngram_worker.py +246 -0
 - sglang/srt/speculative/spec_info.py +49 -0
 - sglang/srt/speculative/spec_utils.py +641 -0
 - sglang/srt/speculative/standalone_worker.py +4 -14
 - sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
 - sglang/srt/tracing/trace.py +32 -6
 - sglang/srt/two_batch_overlap.py +35 -18
 - sglang/srt/utils/__init__.py +2 -0
 - sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
 - sglang/srt/{utils.py → utils/common.py} +583 -113
 - sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
 - sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
 - sglang/srt/{offloader.py → utils/offloader.py} +4 -4
 - sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
 - sglang/srt/utils/profile_merger.py +199 -0
 - sglang/srt/utils/rpd_utils.py +452 -0
 - sglang/srt/utils/slow_rank_detector.py +71 -0
 - sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
 - sglang/srt/warmup.py +8 -4
 - sglang/srt/weight_sync/utils.py +1 -1
 - sglang/test/attention/test_flashattn_backend.py +1 -1
 - sglang/test/attention/test_flashattn_mla_backend.py +0 -1
 - sglang/test/attention/test_prefix_chunk_info.py +0 -2
 - sglang/test/attention/test_trtllm_mla_backend.py +221 -53
 - sglang/test/few_shot_gsm8k_engine.py +2 -4
 - sglang/test/get_logits_ut.py +57 -0
 - sglang/test/kit_matched_stop.py +157 -0
 - sglang/test/longbench_v2/__init__.py +1 -0
 - sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
 - sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
 - sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
 - sglang/test/run_eval.py +120 -11
 - sglang/test/runners.py +3 -1
 - sglang/test/send_one.py +42 -7
 - sglang/test/simple_eval_common.py +8 -2
 - sglang/test/simple_eval_gpqa.py +0 -1
 - sglang/test/simple_eval_humaneval.py +0 -3
 - sglang/test/simple_eval_longbench_v2.py +344 -0
 - sglang/test/simple_eval_mmmu_vlm.py +441 -0
 - sglang/test/test_block_fp8.py +3 -4
 - sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
 - sglang/test/test_cutlass_moe.py +1 -2
 - sglang/test/test_cutlass_w4a8_moe.py +10 -20
 - sglang/test/test_deterministic.py +430 -0
 - sglang/test/test_deterministic_utils.py +73 -0
 - sglang/test/test_disaggregation_utils.py +93 -1
 - sglang/test/test_marlin_moe.py +0 -1
 - sglang/test/test_programs.py +1 -1
 - sglang/test/test_utils.py +432 -16
 - sglang/utils.py +10 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
 - {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
 - sglang/srt/entrypoints/grpc_request_manager.py +0 -580
 - sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
 - sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
 - sglang/srt/mem_cache/lora_radix_cache.py +0 -421
 - sglang/srt/speculative/build_eagle_tree.py +0 -427
 - sglang/test/test_block_fp8_ep.py +0 -358
 - /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
 - /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
 - /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
 - /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
 - {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
 - {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
 
| 
         @@ -12,7 +12,6 @@ 
     | 
|
| 
       12 
12 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       13 
13 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       14 
14 
     | 
    
         
             
            """Common utilities."""
         
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
       16 
15 
     | 
    
         
             
            from __future__ import annotations
         
     | 
| 
       17 
16 
     | 
    
         | 
| 
       18 
17 
     | 
    
         
             
            import argparse
         
     | 
| 
         @@ -22,6 +21,7 @@ import ctypes 
     | 
|
| 
       22 
21 
     | 
    
         
             
            import dataclasses
         
     | 
| 
       23 
22 
     | 
    
         
             
            import functools
         
     | 
| 
       24 
23 
     | 
    
         
             
            import importlib
         
     | 
| 
      
 24 
     | 
    
         
            +
            import inspect
         
     | 
| 
       25 
25 
     | 
    
         
             
            import io
         
     | 
| 
       26 
26 
     | 
    
         
             
            import ipaddress
         
     | 
| 
       27 
27 
     | 
    
         
             
            import itertools
         
     | 
| 
         @@ -42,6 +42,7 @@ import tempfile 
     | 
|
| 
       42 
42 
     | 
    
         
             
            import threading
         
     | 
| 
       43 
43 
     | 
    
         
             
            import time
         
     | 
| 
       44 
44 
     | 
    
         
             
            import traceback
         
     | 
| 
      
 45 
     | 
    
         
            +
            import types
         
     | 
| 
       45 
46 
     | 
    
         
             
            import uuid
         
     | 
| 
       46 
47 
     | 
    
         
             
            import warnings
         
     | 
| 
       47 
48 
     | 
    
         
             
            from collections import OrderedDict, defaultdict
         
     | 
| 
         @@ -55,6 +56,7 @@ from json import JSONDecodeError 
     | 
|
| 
       55 
56 
     | 
    
         
             
            from multiprocessing.reduction import ForkingPickler
         
     | 
| 
       56 
57 
     | 
    
         
             
            from pathlib import Path
         
     | 
| 
       57 
58 
     | 
    
         
             
            from typing import (
         
     | 
| 
      
 59 
     | 
    
         
            +
                TYPE_CHECKING,
         
     | 
| 
       58 
60 
     | 
    
         
             
                Any,
         
     | 
| 
       59 
61 
     | 
    
         
             
                Callable,
         
     | 
| 
       60 
62 
     | 
    
         
             
                Dict,
         
     | 
| 
         @@ -62,6 +64,7 @@ from typing import ( 
     | 
|
| 
       62 
64 
     | 
    
         
             
                List,
         
     | 
| 
       63 
65 
     | 
    
         
             
                Optional,
         
     | 
| 
       64 
66 
     | 
    
         
             
                Protocol,
         
     | 
| 
      
 67 
     | 
    
         
            +
                Sequence,
         
     | 
| 
       65 
68 
     | 
    
         
             
                Set,
         
     | 
| 
       66 
69 
     | 
    
         
             
                Tuple,
         
     | 
| 
       67 
70 
     | 
    
         
             
                TypeVar,
         
     | 
| 
         @@ -69,6 +72,7 @@ from typing import ( 
     | 
|
| 
       69 
72 
     | 
    
         
             
            )
         
     | 
| 
       70 
73 
     | 
    
         | 
| 
       71 
74 
     | 
    
         
             
            import numpy as np
         
     | 
| 
      
 75 
     | 
    
         
            +
            import orjson
         
     | 
| 
       72 
76 
     | 
    
         
             
            import psutil
         
     | 
| 
       73 
77 
     | 
    
         
             
            import pybase64
         
     | 
| 
       74 
78 
     | 
    
         
             
            import requests
         
     | 
| 
         @@ -82,15 +86,17 @@ from packaging import version as pkg_version 
     | 
|
| 
       82 
86 
     | 
    
         
             
            from PIL import Image
         
     | 
| 
       83 
87 
     | 
    
         
             
            from starlette.routing import Mount
         
     | 
| 
       84 
88 
     | 
    
         
             
            from torch import nn
         
     | 
| 
       85 
     | 
    
         
            -
            from torch.func import functional_call
         
     | 
| 
       86 
89 
     | 
    
         
             
            from torch.library import Library
         
     | 
| 
       87 
90 
     | 
    
         
             
            from torch.profiler import ProfilerActivity, profile, record_function
         
     | 
| 
       88 
91 
     | 
    
         
             
            from torch.utils._contextlib import _DecoratorContextManager
         
     | 
| 
       89 
     | 
    
         
            -
            from triton.runtime.cache import FileCacheManager
         
     | 
| 
       90 
92 
     | 
    
         
             
            from typing_extensions import Literal
         
     | 
| 
       91 
93 
     | 
    
         | 
| 
      
 94 
     | 
    
         
            +
            from sglang.srt.environ import envs
         
     | 
| 
       92 
95 
     | 
    
         
             
            from sglang.srt.metrics.func_timer import enable_func_timer
         
     | 
| 
       93 
96 
     | 
    
         | 
| 
      
 97 
     | 
    
         
            +
            if TYPE_CHECKING:
         
     | 
| 
      
 98 
     | 
    
         
            +
                from sglang.srt.layers.quantization.base_config import QuantizeMethodBase
         
     | 
| 
      
 99 
     | 
    
         
            +
             
     | 
| 
       94 
100 
     | 
    
         
             
            logger = logging.getLogger(__name__)
         
     | 
| 
       95 
101 
     | 
    
         | 
| 
       96 
102 
     | 
    
         
             
            show_time_cost = False
         
     | 
| 
         @@ -163,18 +169,44 @@ def _check(cc_major): 
     | 
|
| 
       163 
169 
     | 
    
         
             
                ) >= (12, 3)
         
     | 
| 
       164 
170 
     | 
    
         | 
| 
       165 
171 
     | 
    
         | 
| 
      
 172 
     | 
    
         
            +
            @contextmanager
         
     | 
| 
      
 173 
     | 
    
         
            +
            def device_context(device: torch.device):
         
     | 
| 
      
 174 
     | 
    
         
            +
                if device.type == "cpu" and is_cpu():
         
     | 
| 
      
 175 
     | 
    
         
            +
                    with torch.device("cpu"):
         
     | 
| 
      
 176 
     | 
    
         
            +
                        yield
         
     | 
| 
      
 177 
     | 
    
         
            +
                else:
         
     | 
| 
      
 178 
     | 
    
         
            +
                    module = torch.get_device_module(device)
         
     | 
| 
      
 179 
     | 
    
         
            +
                    if module is not None:
         
     | 
| 
      
 180 
     | 
    
         
            +
                        with module.device(device.index):
         
     | 
| 
      
 181 
     | 
    
         
            +
                            yield
         
     | 
| 
      
 182 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 183 
     | 
    
         
            +
                        raise ValueError(f"Unknown device module: {device}")
         
     | 
| 
      
 184 
     | 
    
         
            +
             
     | 
| 
      
 185 
     | 
    
         
            +
             
     | 
| 
       166 
186 
     | 
    
         
             
            is_ampere_with_cuda_12_3 = lambda: _check(8)
         
     | 
| 
       167 
187 
     | 
    
         
             
            is_hopper_with_cuda_12_3 = lambda: _check(9)
         
     | 
| 
       168 
188 
     | 
    
         | 
| 
       169 
189 
     | 
    
         | 
| 
      
 190 
     | 
    
         
            +
            @lru_cache(maxsize=1)
         
     | 
| 
       170 
191 
     | 
    
         
             
            def is_blackwell():
         
     | 
| 
       171 
192 
     | 
    
         
             
                if not is_cuda():
         
     | 
| 
       172 
193 
     | 
    
         
             
                    return False
         
     | 
| 
       173 
194 
     | 
    
         
             
                return torch.cuda.get_device_capability()[0] == 10
         
     | 
| 
       174 
195 
     | 
    
         | 
| 
       175 
196 
     | 
    
         | 
| 
      
 197 
     | 
    
         
            +
            @lru_cache(maxsize=1)
         
     | 
| 
      
 198 
     | 
    
         
            +
            def is_sm120_supported(device=None) -> bool:
         
     | 
| 
      
 199 
     | 
    
         
            +
                if not is_cuda_alike():
         
     | 
| 
      
 200 
     | 
    
         
            +
                    return False
         
     | 
| 
      
 201 
     | 
    
         
            +
                return (torch.cuda.get_device_capability(device)[0] == 12) and (
         
     | 
| 
      
 202 
     | 
    
         
            +
                    torch.version.cuda >= "12.8"
         
     | 
| 
      
 203 
     | 
    
         
            +
                )
         
     | 
| 
      
 204 
     | 
    
         
            +
             
     | 
| 
      
 205 
     | 
    
         
            +
             
     | 
| 
       176 
206 
     | 
    
         
             
            @lru_cache(maxsize=1)
         
     | 
| 
       177 
207 
     | 
    
         
             
            def is_sm100_supported(device=None) -> bool:
         
     | 
| 
      
 208 
     | 
    
         
            +
                if not is_cuda_alike():
         
     | 
| 
      
 209 
     | 
    
         
            +
                    return False
         
     | 
| 
       178 
210 
     | 
    
         
             
                return (torch.cuda.get_device_capability(device)[0] == 10) and (
         
     | 
| 
       179 
211 
     | 
    
         
             
                    torch.version.cuda >= "12.8"
         
     | 
| 
       180 
212 
     | 
    
         
             
                )
         
     | 
| 
         @@ -182,6 +214,8 @@ def is_sm100_supported(device=None) -> bool: 
     | 
|
| 
       182 
214 
     | 
    
         | 
| 
       183 
215 
     | 
    
         
             
            @lru_cache(maxsize=1)
         
     | 
| 
       184 
216 
     | 
    
         
             
            def is_sm90_supported(device=None) -> bool:
         
     | 
| 
      
 217 
     | 
    
         
            +
                if not is_cuda_alike():
         
     | 
| 
      
 218 
     | 
    
         
            +
                    return False
         
     | 
| 
       185 
219 
     | 
    
         
             
                return (torch.cuda.get_device_capability(device)[0] == 9) and (
         
     | 
| 
       186 
220 
     | 
    
         
             
                    torch.version.cuda >= "12.3"
         
     | 
| 
       187 
221 
     | 
    
         
             
                )
         
     | 
| 
         @@ -191,6 +225,7 @@ _warned_bool_env_var_keys = set() 
     | 
|
| 
       191 
225 
     | 
    
         | 
| 
       192 
226 
     | 
    
         | 
| 
       193 
227 
     | 
    
         
             
            def get_bool_env_var(name: str, default: str = "false") -> bool:
         
     | 
| 
      
 228 
     | 
    
         
            +
                # FIXME: move your environment variable to sglang.srt.environ
         
     | 
| 
       194 
229 
     | 
    
         
             
                value = os.getenv(name, default)
         
     | 
| 
       195 
230 
     | 
    
         
             
                value = value.lower()
         
     | 
| 
       196 
231 
     | 
    
         | 
| 
         @@ -208,6 +243,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool: 
     | 
|
| 
       208 
243 
     | 
    
         | 
| 
       209 
244 
     | 
    
         | 
| 
       210 
245 
     | 
    
         
             
            def get_int_env_var(name: str, default: int = 0) -> int:
         
     | 
| 
      
 246 
     | 
    
         
            +
                # FIXME: move your environment variable to sglang.srt.environ
         
     | 
| 
       211 
247 
     | 
    
         
             
                value = os.getenv(name)
         
     | 
| 
       212 
248 
     | 
    
         
             
                if value is None or not value.strip():
         
     | 
| 
       213 
249 
     | 
    
         
             
                    return default
         
     | 
| 
         @@ -222,7 +258,7 @@ def support_triton(backend: str) -> bool: 
     | 
|
| 
       222 
258 
     | 
    
         | 
| 
       223 
259 
     | 
    
         | 
| 
       224 
260 
     | 
    
         
             
            try:
         
     | 
| 
       225 
     | 
    
         
            -
                import sgl_kernel
         
     | 
| 
      
 261 
     | 
    
         
            +
                import sgl_kernel  # noqa: F401
         
     | 
| 
       226 
262 
     | 
    
         | 
| 
       227 
263 
     | 
    
         
             
                is_intel_amx_backend_available = hasattr(
         
     | 
| 
       228 
264 
     | 
    
         
             
                    torch.ops.sgl_kernel, "convert_weight_packed"
         
     | 
| 
         @@ -247,6 +283,14 @@ def use_intel_amx_backend(layer): 
     | 
|
| 
       247 
283 
     | 
    
         
             
                return getattr(layer, "use_intel_amx_backend", False)
         
     | 
| 
       248 
284 
     | 
    
         | 
| 
       249 
285 
     | 
    
         | 
| 
      
 286 
     | 
    
         
            +
            def xpu_has_xmx_support():
         
     | 
| 
      
 287 
     | 
    
         
            +
                # TODO: update with XPU capalibity query
         
     | 
| 
      
 288 
     | 
    
         
            +
                if is_xpu():
         
     | 
| 
      
 289 
     | 
    
         
            +
                    # currently only PVC/LNL/BMG supports F64, so we only support these now
         
     | 
| 
      
 290 
     | 
    
         
            +
                    return torch.xpu.get_device_properties().has_fp64
         
     | 
| 
      
 291 
     | 
    
         
            +
                return False
         
     | 
| 
      
 292 
     | 
    
         
            +
             
     | 
| 
      
 293 
     | 
    
         
            +
             
     | 
| 
       250 
294 
     | 
    
         
             
            def is_flashinfer_available():
         
     | 
| 
       251 
295 
     | 
    
         
             
                """
         
     | 
| 
       252 
296 
     | 
    
         
             
                Check whether flashinfer is available.
         
     | 
| 
         @@ -257,6 +301,17 @@ def is_flashinfer_available(): 
     | 
|
| 
       257 
301 
     | 
    
         
             
                return importlib.util.find_spec("flashinfer") is not None and is_cuda()
         
     | 
| 
       258 
302 
     | 
    
         | 
| 
       259 
303 
     | 
    
         | 
| 
      
 304 
     | 
    
         
            +
            def is_nvidia_cublas_cu12_version_ge_12_9():
         
     | 
| 
      
 305 
     | 
    
         
            +
                """
         
     | 
| 
      
 306 
     | 
    
         
            +
                temporary fix for issue #11272
         
     | 
| 
      
 307 
     | 
    
         
            +
                """
         
     | 
| 
      
 308 
     | 
    
         
            +
                try:
         
     | 
| 
      
 309 
     | 
    
         
            +
                    installed_version = version("nvidia-cublas-cu12")
         
     | 
| 
      
 310 
     | 
    
         
            +
                except PackageNotFoundError:
         
     | 
| 
      
 311 
     | 
    
         
            +
                    return False
         
     | 
| 
      
 312 
     | 
    
         
            +
                return pkg_version.parse(installed_version) >= pkg_version.parse("12.9")
         
     | 
| 
      
 313 
     | 
    
         
            +
             
     | 
| 
      
 314 
     | 
    
         
            +
             
     | 
| 
       260 
315 
     | 
    
         
             
            def random_uuid() -> str:
         
     | 
| 
       261 
316 
     | 
    
         
             
                return str(uuid.uuid4().hex)
         
     | 
| 
       262 
317 
     | 
    
         | 
| 
         @@ -403,7 +458,15 @@ def get_available_gpu_memory( 
     | 
|
| 
       403 
458 
     | 
    
         | 
| 
       404 
459 
     | 
    
         
             
                    if empty_cache:
         
     | 
| 
       405 
460 
     | 
    
         
             
                        torch.cuda.empty_cache()
         
     | 
| 
       406 
     | 
    
         
            -
                     
     | 
| 
      
 461 
     | 
    
         
            +
                    SHARED_SYSMEM_DEVICE_MEM_SMS = (87, 110, 121)  # Orin, Thor, Spark
         
     | 
| 
      
 462 
     | 
    
         
            +
                    if get_device_sm() in SHARED_SYSMEM_DEVICE_MEM_SMS:
         
     | 
| 
      
 463 
     | 
    
         
            +
                        # On these devices, which use sysmem as device mem, torch.cuda.mem_get_info()
         
     | 
| 
      
 464 
     | 
    
         
            +
                        # only reports "free" memory, which can be lower than what is actually
         
     | 
| 
      
 465 
     | 
    
         
            +
                        # available due to not including cache memory. So we use the system available
         
     | 
| 
      
 466 
     | 
    
         
            +
                        # memory metric instead.
         
     | 
| 
      
 467 
     | 
    
         
            +
                        free_gpu_memory = psutil.virtual_memory().available
         
     | 
| 
      
 468 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 469 
     | 
    
         
            +
                        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
         
     | 
| 
       407 
470 
     | 
    
         | 
| 
       408 
471 
     | 
    
         
             
                elif device == "xpu":
         
     | 
| 
       409 
472 
     | 
    
         
             
                    num_gpus = torch.xpu.device_count()
         
     | 
| 
         @@ -447,6 +510,8 @@ def get_available_gpu_memory( 
     | 
|
| 
       447 
510 
     | 
    
         
             
                            f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
         
     | 
| 
       448 
511 
     | 
    
         
             
                            "which may cause useless memory allocation for torch NPU context.",
         
     | 
| 
       449 
512 
     | 
    
         
             
                        )
         
     | 
| 
      
 513 
     | 
    
         
            +
                    if empty_cache:
         
     | 
| 
      
 514 
     | 
    
         
            +
                        torch.npu.empty_cache()
         
     | 
| 
       450 
515 
     | 
    
         
             
                    free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
         
     | 
| 
       451 
516 
     | 
    
         | 
| 
       452 
517 
     | 
    
         
             
                if distributed:
         
     | 
| 
         @@ -465,7 +530,7 @@ def is_pin_memory_available() -> bool: 
     | 
|
| 
       465 
530 
     | 
    
         | 
| 
       466 
531 
     | 
    
         
             
            class LayerFn(Protocol):
         
     | 
| 
       467 
532 
     | 
    
         | 
| 
       468 
     | 
    
         
            -
                def __call__(self,  
     | 
| 
      
 533 
     | 
    
         
            +
                def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...
         
     | 
| 
       469 
534 
     | 
    
         | 
| 
       470 
535 
     | 
    
         | 
| 
       471 
536 
     | 
    
         
             
            def make_layers(
         
     | 
| 
         @@ -475,13 +540,13 @@ def make_layers( 
     | 
|
| 
       475 
540 
     | 
    
         
             
                pp_size: Optional[int] = None,
         
     | 
| 
       476 
541 
     | 
    
         
             
                prefix: str = "",
         
     | 
| 
       477 
542 
     | 
    
         
             
                return_tuple: bool = False,
         
     | 
| 
       478 
     | 
    
         
            -
                offloader_kwargs: Dict[str, Any] =  
     | 
| 
       479 
     | 
    
         
            -
            ) -> Tuple[ 
     | 
| 
      
 543 
     | 
    
         
            +
                offloader_kwargs: Optional[Dict[str, Any]] = None,
         
     | 
| 
      
 544 
     | 
    
         
            +
            ) -> Tuple[torch.nn.Module, int, int]:
         
     | 
| 
       480 
545 
     | 
    
         
             
                """Make a list of layers with the given layer function"""
         
     | 
| 
       481 
546 
     | 
    
         
             
                # circula imports
         
     | 
| 
       482 
547 
     | 
    
         
             
                from sglang.srt.distributed import get_pp_indices
         
     | 
| 
       483 
548 
     | 
    
         
             
                from sglang.srt.layers.utils import PPMissingLayer
         
     | 
| 
       484 
     | 
    
         
            -
                from sglang.srt.offloader import get_offloader
         
     | 
| 
      
 549 
     | 
    
         
            +
                from sglang.srt.utils.offloader import get_offloader
         
     | 
| 
       485 
550 
     | 
    
         | 
| 
       486 
551 
     | 
    
         
             
                assert not pp_size or num_hidden_layers >= pp_size
         
     | 
| 
       487 
552 
     | 
    
         
             
                start_layer, end_layer = (
         
     | 
| 
         @@ -500,7 +565,7 @@ def make_layers( 
     | 
|
| 
       500 
565 
     | 
    
         
             
                            layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
         
     | 
| 
       501 
566 
     | 
    
         
             
                            for idx in range(start_layer, end_layer)
         
     | 
| 
       502 
567 
     | 
    
         
             
                        ),
         
     | 
| 
       503 
     | 
    
         
            -
                        **offloader_kwargs,
         
     | 
| 
      
 568 
     | 
    
         
            +
                        **(offloader_kwargs or {}),
         
     | 
| 
       504 
569 
     | 
    
         
             
                    )
         
     | 
| 
       505 
570 
     | 
    
         
             
                    + [
         
     | 
| 
       506 
571 
     | 
    
         
             
                        PPMissingLayer(return_tuple=return_tuple)
         
     | 
| 
         @@ -512,6 +577,68 @@ def make_layers( 
     | 
|
| 
       512 
577 
     | 
    
         
             
                return modules, start_layer, end_layer
         
     | 
| 
       513 
578 
     | 
    
         | 
| 
       514 
579 
     | 
    
         | 
| 
      
 580 
     | 
    
         
            +
            def make_layers_non_pp(
         
     | 
| 
      
 581 
     | 
    
         
            +
                num_hidden_layers: int,
         
     | 
| 
      
 582 
     | 
    
         
            +
                layer_fn: LayerFn,
         
     | 
| 
      
 583 
     | 
    
         
            +
                prefix: str = "",
         
     | 
| 
      
 584 
     | 
    
         
            +
            ) -> torch.nn.ModuleList:
         
     | 
| 
      
 585 
     | 
    
         
            +
                from sglang.srt.utils.offloader import get_offloader
         
     | 
| 
      
 586 
     | 
    
         
            +
             
     | 
| 
      
 587 
     | 
    
         
            +
                layers = torch.nn.ModuleList(
         
     | 
| 
      
 588 
     | 
    
         
            +
                    get_offloader().wrap_modules(
         
     | 
| 
      
 589 
     | 
    
         
            +
                        (
         
     | 
| 
      
 590 
     | 
    
         
            +
                            layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
         
     | 
| 
      
 591 
     | 
    
         
            +
                            for idx in range(num_hidden_layers)
         
     | 
| 
      
 592 
     | 
    
         
            +
                        )
         
     | 
| 
      
 593 
     | 
    
         
            +
                    )
         
     | 
| 
      
 594 
     | 
    
         
            +
                )
         
     | 
| 
      
 595 
     | 
    
         
            +
                return layers
         
     | 
| 
      
 596 
     | 
    
         
            +
             
     | 
| 
      
 597 
     | 
    
         
            +
             
     | 
| 
      
 598 
     | 
    
         
            +
            cmo_stream = None
         
     | 
| 
      
 599 
     | 
    
         
            +
             
     | 
| 
      
 600 
     | 
    
         
            +
             
     | 
| 
      
 601 
     | 
    
         
            +
            def get_cmo_stream():
         
     | 
| 
      
 602 
     | 
    
         
            +
                """
         
     | 
| 
      
 603 
     | 
    
         
            +
                Cache Management Operation(CMO).
         
     | 
| 
      
 604 
     | 
    
         
            +
                Launch a new stream to prefetch the weight of matmul when running other
         
     | 
| 
      
 605 
     | 
    
         
            +
                AIV or communication kernels, aiming to overlap the memory access time.
         
     | 
| 
      
 606 
     | 
    
         
            +
                """
         
     | 
| 
      
 607 
     | 
    
         
            +
                global cmo_stream
         
     | 
| 
      
 608 
     | 
    
         
            +
                if cmo_stream is None:
         
     | 
| 
      
 609 
     | 
    
         
            +
                    cmo_stream = torch.get_device_module().Stream()
         
     | 
| 
      
 610 
     | 
    
         
            +
                return cmo_stream
         
     | 
| 
      
 611 
     | 
    
         
            +
             
     | 
| 
      
 612 
     | 
    
         
            +
             
     | 
| 
      
 613 
     | 
    
         
            +
            def prepare_weight_cache(handle, cache):
         
     | 
| 
      
 614 
     | 
    
         
            +
                import torch_npu
         
     | 
| 
      
 615 
     | 
    
         
            +
             
     | 
| 
      
 616 
     | 
    
         
            +
                NPU_PREFETCH_MAX_SIZE_BYTES = (
         
     | 
| 
      
 617 
     | 
    
         
            +
                    1000000000  # 1GB, a large value to prefetch entire weight
         
     | 
| 
      
 618 
     | 
    
         
            +
                )
         
     | 
| 
      
 619 
     | 
    
         
            +
                stream = get_cmo_stream()
         
     | 
| 
      
 620 
     | 
    
         
            +
                stream.wait_stream(torch.npu.current_stream())
         
     | 
| 
      
 621 
     | 
    
         
            +
                with torch.npu.stream(stream):
         
     | 
| 
      
 622 
     | 
    
         
            +
                    if isinstance(cache, list):
         
     | 
| 
      
 623 
     | 
    
         
            +
                        for weight in cache:
         
     | 
| 
      
 624 
     | 
    
         
            +
                            torch_npu.npu_prefetch(
         
     | 
| 
      
 625 
     | 
    
         
            +
                                weight,
         
     | 
| 
      
 626 
     | 
    
         
            +
                                handle,
         
     | 
| 
      
 627 
     | 
    
         
            +
                                NPU_PREFETCH_MAX_SIZE_BYTES,
         
     | 
| 
      
 628 
     | 
    
         
            +
                            )
         
     | 
| 
      
 629 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 630 
     | 
    
         
            +
                        torch_npu.npu_prefetch(
         
     | 
| 
      
 631 
     | 
    
         
            +
                            cache,
         
     | 
| 
      
 632 
     | 
    
         
            +
                            handle,
         
     | 
| 
      
 633 
     | 
    
         
            +
                            NPU_PREFETCH_MAX_SIZE_BYTES,
         
     | 
| 
      
 634 
     | 
    
         
            +
                        )
         
     | 
| 
      
 635 
     | 
    
         
            +
             
     | 
| 
      
 636 
     | 
    
         
            +
             
     | 
| 
      
 637 
     | 
    
         
            +
            def wait_cmo_stream():
         
     | 
| 
      
 638 
     | 
    
         
            +
                cur_stream = torch.get_device_module().current_stream()
         
     | 
| 
      
 639 
     | 
    
         
            +
                cur_stream.wait_stream(get_cmo_stream())
         
     | 
| 
      
 640 
     | 
    
         
            +
             
     | 
| 
      
 641 
     | 
    
         
            +
             
     | 
| 
       515 
642 
     | 
    
         
             
            def set_random_seed(seed: int) -> None:
         
     | 
| 
       516 
643 
     | 
    
         
             
                """Set the random seed for all libraries."""
         
     | 
| 
       517 
644 
     | 
    
         
             
                random.seed(seed)
         
     | 
| 
         @@ -749,6 +876,25 @@ def load_image( 
     | 
|
| 
       749 
876 
     | 
    
         
             
                return image, image_size
         
     | 
| 
       750 
877 
     | 
    
         | 
| 
       751 
878 
     | 
    
         | 
| 
      
 879 
     | 
    
         
            +
            def get_image_bytes(image_file: Union[str, bytes]):
         
     | 
| 
      
 880 
     | 
    
         
            +
                if isinstance(image_file, bytes):
         
     | 
| 
      
 881 
     | 
    
         
            +
                    return image_file
         
     | 
| 
      
 882 
     | 
    
         
            +
                elif image_file.startswith("http://") or image_file.startswith("https://"):
         
     | 
| 
      
 883 
     | 
    
         
            +
                    timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
         
     | 
| 
      
 884 
     | 
    
         
            +
                    response = requests.get(image_file, timeout=timeout)
         
     | 
| 
      
 885 
     | 
    
         
            +
                    return response.content
         
     | 
| 
      
 886 
     | 
    
         
            +
                elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         
     | 
| 
      
 887 
     | 
    
         
            +
                    with open(image_file, "rb") as f:
         
     | 
| 
      
 888 
     | 
    
         
            +
                        return f.read()
         
     | 
| 
      
 889 
     | 
    
         
            +
                elif image_file.startswith("data:"):
         
     | 
| 
      
 890 
     | 
    
         
            +
                    image_file = image_file.split(",")[1]
         
     | 
| 
      
 891 
     | 
    
         
            +
                    return pybase64.b64decode(image_file, validate=True)
         
     | 
| 
      
 892 
     | 
    
         
            +
                elif isinstance(image_file, str):
         
     | 
| 
      
 893 
     | 
    
         
            +
                    return pybase64.b64decode(image_file, validate=True)
         
     | 
| 
      
 894 
     | 
    
         
            +
                else:
         
     | 
| 
      
 895 
     | 
    
         
            +
                    raise NotImplementedError(f"Invalid image: {image_file}")
         
     | 
| 
      
 896 
     | 
    
         
            +
             
     | 
| 
      
 897 
     | 
    
         
            +
             
     | 
| 
       752 
898 
     | 
    
         
             
            def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
         
     | 
| 
       753 
899 
     | 
    
         
             
                # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
         
     | 
| 
       754 
900 
     | 
    
         
             
                from decord import VideoReader, cpu, gpu
         
     | 
| 
         @@ -781,7 +927,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True): 
     | 
|
| 
       781 
927 
     | 
    
         
             
                            vr = VideoReader(tmp_file.name, ctx=ctx)
         
     | 
| 
       782 
928 
     | 
    
         
             
                        elif video_file.startswith("data:"):
         
     | 
| 
       783 
929 
     | 
    
         
             
                            _, encoded = video_file.split(",", 1)
         
     | 
| 
       784 
     | 
    
         
            -
                            video_bytes = pybase64.b64decode(encoded)
         
     | 
| 
      
 930 
     | 
    
         
            +
                            video_bytes = pybase64.b64decode(encoded, validate=True)
         
     | 
| 
       785 
931 
     | 
    
         
             
                            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
         
     | 
| 
       786 
932 
     | 
    
         
             
                            tmp_file.write(video_bytes)
         
     | 
| 
       787 
933 
     | 
    
         
             
                            tmp_file.close()
         
     | 
| 
         @@ -789,7 +935,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True): 
     | 
|
| 
       789 
935 
     | 
    
         
             
                        elif os.path.isfile(video_file):
         
     | 
| 
       790 
936 
     | 
    
         
             
                            vr = VideoReader(video_file, ctx=ctx)
         
     | 
| 
       791 
937 
     | 
    
         
             
                        else:
         
     | 
| 
       792 
     | 
    
         
            -
                            video_bytes = pybase64.b64decode(video_file)
         
     | 
| 
      
 938 
     | 
    
         
            +
                            video_bytes = pybase64.b64decode(video_file, validate=True)
         
     | 
| 
       793 
939 
     | 
    
         
             
                            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
         
     | 
| 
       794 
940 
     | 
    
         
             
                            tmp_file.write(video_bytes)
         
     | 
| 
       795 
941 
     | 
    
         
             
                            tmp_file.close()
         
     | 
| 
         @@ -804,6 +950,33 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True): 
     | 
|
| 
       804 
950 
     | 
    
         
             
                        os.unlink(tmp_file.name)
         
     | 
| 
       805 
951 
     | 
    
         | 
| 
       806 
952 
     | 
    
         | 
| 
      
 953 
     | 
    
         
            +
            def encode_video(video_path, frame_count_limit=None):
         
     | 
| 
      
 954 
     | 
    
         
            +
                # Lazy import because decord is not available on some arm platforms.
         
     | 
| 
      
 955 
     | 
    
         
            +
                from decord import VideoReader, cpu
         
     | 
| 
      
 956 
     | 
    
         
            +
             
     | 
| 
      
 957 
     | 
    
         
            +
                if not os.path.exists(video_path):
         
     | 
| 
      
 958 
     | 
    
         
            +
                    logger.error(f"Video {video_path} does not exist")
         
     | 
| 
      
 959 
     | 
    
         
            +
                    return []
         
     | 
| 
      
 960 
     | 
    
         
            +
             
     | 
| 
      
 961 
     | 
    
         
            +
                if frame_count_limit == 0:
         
     | 
| 
      
 962 
     | 
    
         
            +
                    return []
         
     | 
| 
      
 963 
     | 
    
         
            +
             
     | 
| 
      
 964 
     | 
    
         
            +
                def uniform_sample(l, n):
         
     | 
| 
      
 965 
     | 
    
         
            +
                    gap = len(l) / n
         
     | 
| 
      
 966 
     | 
    
         
            +
                    idxs = [int(i * gap + gap / 2) for i in range(n)]
         
     | 
| 
      
 967 
     | 
    
         
            +
                    return [l[i] for i in idxs]
         
     | 
| 
      
 968 
     | 
    
         
            +
             
     | 
| 
      
 969 
     | 
    
         
            +
                vr = VideoReader(video_path, ctx=cpu(0))
         
     | 
| 
      
 970 
     | 
    
         
            +
                sample_fps = round(vr.get_avg_fps() / 1)  # FPS
         
     | 
| 
      
 971 
     | 
    
         
            +
                frame_indices = [i for i in range(0, len(vr), sample_fps)]
         
     | 
| 
      
 972 
     | 
    
         
            +
                if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
         
     | 
| 
      
 973 
     | 
    
         
            +
                    frame_indices = uniform_sample(frame_indices, frame_count_limit)
         
     | 
| 
      
 974 
     | 
    
         
            +
             
     | 
| 
      
 975 
     | 
    
         
            +
                frames = vr.get_batch(frame_indices).asnumpy()
         
     | 
| 
      
 976 
     | 
    
         
            +
                frames = [Image.fromarray(v.astype("uint8")) for v in frames]
         
     | 
| 
      
 977 
     | 
    
         
            +
                return frames
         
     | 
| 
      
 978 
     | 
    
         
            +
             
     | 
| 
      
 979 
     | 
    
         
            +
             
     | 
| 
       807 
980 
     | 
    
         
             
            def suppress_other_loggers():
         
     | 
| 
       808 
981 
     | 
    
         
             
                warnings.filterwarnings(
         
     | 
| 
       809 
982 
     | 
    
         
             
                    "ignore", category=UserWarning, message="The given NumPy array is not writable"
         
     | 
| 
         @@ -911,7 +1084,7 @@ def monkey_patch_vllm_gguf_config(): 
     | 
|
| 
       911 
1084 
     | 
    
         | 
| 
       912 
1085 
     | 
    
         
             
                def get_quant_method_with_embedding_replaced(
         
     | 
| 
       913 
1086 
     | 
    
         
             
                    self, layer: torch.nn.Module, prefix: str
         
     | 
| 
       914 
     | 
    
         
            -
                ) -> Optional[ 
     | 
| 
      
 1087 
     | 
    
         
            +
                ) -> Optional[QuantizeMethodBase]:
         
     | 
| 
       915 
1088 
     | 
    
         
             
                    if isinstance(layer, LinearBase):
         
     | 
| 
       916 
1089 
     | 
    
         
             
                        return GGUFLinearMethod(self)
         
     | 
| 
       917 
1090 
     | 
    
         
             
                    elif isinstance(layer, VocabParallelEmbedding):
         
     | 
| 
         @@ -946,6 +1119,13 @@ def set_ulimit(target_soft_limit=65535): 
     | 
|
| 
       946 
1119 
     | 
    
         
             
                        logger.warning(f"Fail to set RLIMIT_STACK: {e}")
         
     | 
| 
       947 
1120 
     | 
    
         | 
| 
       948 
1121 
     | 
    
         | 
| 
      
 1122 
     | 
    
         
            +
            def rank0_log(msg: str):
         
     | 
| 
      
 1123 
     | 
    
         
            +
                from sglang.srt.distributed import get_tensor_model_parallel_rank
         
     | 
| 
      
 1124 
     | 
    
         
            +
             
     | 
| 
      
 1125 
     | 
    
         
            +
                if get_tensor_model_parallel_rank() == 0:
         
     | 
| 
      
 1126 
     | 
    
         
            +
                    logger.info(msg)
         
     | 
| 
      
 1127 
     | 
    
         
            +
             
     | 
| 
      
 1128 
     | 
    
         
            +
             
     | 
| 
       949 
1129 
     | 
    
         
             
            def add_api_key_middleware(app, api_key: str):
         
     | 
| 
       950 
1130 
     | 
    
         
             
                @app.middleware("http")
         
     | 
| 
       951 
1131 
     | 
    
         
             
                async def authentication(request, call_next):
         
     | 
| 
         @@ -980,7 +1160,7 @@ def configure_logger(server_args, prefix: str = ""): 
     | 
|
| 
       980 
1160 
     | 
    
         
             
                            f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
         
     | 
| 
       981 
1161 
     | 
    
         
             
                        )
         
     | 
| 
       982 
1162 
     | 
    
         
             
                    with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
         
     | 
| 
       983 
     | 
    
         
            -
                        custom_config =  
     | 
| 
      
 1163 
     | 
    
         
            +
                        custom_config = orjson.loads(file.read())
         
     | 
| 
       984 
1164 
     | 
    
         
             
                    logging.config.dictConfig(custom_config)
         
     | 
| 
       985 
1165 
     | 
    
         
             
                    return
         
     | 
| 
       986 
1166 
     | 
    
         
             
                format = f"[%(asctime)s{prefix}] %(message)s"
         
     | 
| 
         @@ -1159,8 +1339,46 @@ def pytorch_profile(name, func, *args, data_size=-1): 
     | 
|
| 
       1159 
1339 
     | 
    
         | 
| 
       1160 
1340 
     | 
    
         | 
| 
       1161 
1341 
     | 
    
         
             
            def get_zmq_socket(
         
     | 
| 
       1162 
     | 
    
         
            -
                context: zmq.Context, 
     | 
| 
       1163 
     | 
    
         
            -
             
     | 
| 
      
 1342 
     | 
    
         
            +
                context: zmq.Context,
         
     | 
| 
      
 1343 
     | 
    
         
            +
                socket_type: zmq.SocketType,
         
     | 
| 
      
 1344 
     | 
    
         
            +
                endpoint: Optional[str] = None,
         
     | 
| 
      
 1345 
     | 
    
         
            +
                bind: bool = True,
         
     | 
| 
      
 1346 
     | 
    
         
            +
            ) -> Union[zmq.Socket, Tuple[int, zmq.Socket]]:
         
     | 
| 
      
 1347 
     | 
    
         
            +
                """Create and configure a ZeroMQ socket.
         
     | 
| 
      
 1348 
     | 
    
         
            +
             
     | 
| 
      
 1349 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 1350 
     | 
    
         
            +
                    context: ZeroMQ context to create the socket from.
         
     | 
| 
      
 1351 
     | 
    
         
            +
                    socket_type: Type of ZeroMQ socket to create.
         
     | 
| 
      
 1352 
     | 
    
         
            +
                    endpoint: Optional endpoint to bind/connect to. If None, binds to a random TCP port.
         
     | 
| 
      
 1353 
     | 
    
         
            +
                    bind: Whether to bind (True) or connect (False) to the endpoint. Ignored if endpoint is None.
         
     | 
| 
      
 1354 
     | 
    
         
            +
             
     | 
| 
      
 1355 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 1356 
     | 
    
         
            +
                    If endpoint is None: Tuple of (port, socket) where port is the randomly assigned TCP port.
         
     | 
| 
      
 1357 
     | 
    
         
            +
                    If endpoint is provided: The configured ZeroMQ socket.
         
     | 
| 
      
 1358 
     | 
    
         
            +
                """
         
     | 
| 
      
 1359 
     | 
    
         
            +
                socket = context.socket(socket_type)
         
     | 
| 
      
 1360 
     | 
    
         
            +
             
     | 
| 
      
 1361 
     | 
    
         
            +
                if endpoint is None:
         
     | 
| 
      
 1362 
     | 
    
         
            +
                    # Bind to random TCP port
         
     | 
| 
      
 1363 
     | 
    
         
            +
                    config_socket(socket, socket_type)
         
     | 
| 
      
 1364 
     | 
    
         
            +
                    port = socket.bind_to_random_port("tcp://*")
         
     | 
| 
      
 1365 
     | 
    
         
            +
                    return port, socket
         
     | 
| 
      
 1366 
     | 
    
         
            +
                else:
         
     | 
| 
      
 1367 
     | 
    
         
            +
                    # Handle IPv6 if endpoint contains brackets
         
     | 
| 
      
 1368 
     | 
    
         
            +
                    if endpoint.find("[") != -1:
         
     | 
| 
      
 1369 
     | 
    
         
            +
                        socket.setsockopt(zmq.IPV6, 1)
         
     | 
| 
      
 1370 
     | 
    
         
            +
             
     | 
| 
      
 1371 
     | 
    
         
            +
                    config_socket(socket, socket_type)
         
     | 
| 
      
 1372 
     | 
    
         
            +
             
     | 
| 
      
 1373 
     | 
    
         
            +
                    if bind:
         
     | 
| 
      
 1374 
     | 
    
         
            +
                        socket.bind(endpoint)
         
     | 
| 
      
 1375 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 1376 
     | 
    
         
            +
                        socket.connect(endpoint)
         
     | 
| 
      
 1377 
     | 
    
         
            +
             
     | 
| 
      
 1378 
     | 
    
         
            +
                    return socket
         
     | 
| 
      
 1379 
     | 
    
         
            +
             
     | 
| 
      
 1380 
     | 
    
         
            +
             
     | 
| 
      
 1381 
     | 
    
         
            +
            def config_socket(socket, socket_type: zmq.SocketType):
         
     | 
| 
       1164 
1382 
     | 
    
         
             
                mem = psutil.virtual_memory()
         
     | 
| 
       1165 
1383 
     | 
    
         
             
                total_mem = mem.total / 1024**3
         
     | 
| 
       1166 
1384 
     | 
    
         
             
                available_mem = mem.available / 1024**3
         
     | 
| 
         @@ -1169,10 +1387,6 @@ def get_zmq_socket( 
     | 
|
| 
       1169 
1387 
     | 
    
         
             
                else:
         
     | 
| 
       1170 
1388 
     | 
    
         
             
                    buf_size = -1
         
     | 
| 
       1171 
1389 
     | 
    
         | 
| 
       1172 
     | 
    
         
            -
                socket = context.socket(socket_type)
         
     | 
| 
       1173 
     | 
    
         
            -
                if endpoint.find("[") != -1:
         
     | 
| 
       1174 
     | 
    
         
            -
                    socket.setsockopt(zmq.IPV6, 1)
         
     | 
| 
       1175 
     | 
    
         
            -
             
     | 
| 
       1176 
1390 
     | 
    
         
             
                def set_send_opt():
         
     | 
| 
       1177 
1391 
     | 
    
         
             
                    socket.setsockopt(zmq.SNDHWM, 0)
         
     | 
| 
       1178 
1392 
     | 
    
         
             
                    socket.setsockopt(zmq.SNDBUF, buf_size)
         
     | 
| 
         @@ -1185,19 +1399,12 @@ def get_zmq_socket( 
     | 
|
| 
       1185 
1399 
     | 
    
         
             
                    set_send_opt()
         
     | 
| 
       1186 
1400 
     | 
    
         
             
                elif socket_type == zmq.PULL:
         
     | 
| 
       1187 
1401 
     | 
    
         
             
                    set_recv_opt()
         
     | 
| 
       1188 
     | 
    
         
            -
                elif socket_type  
     | 
| 
      
 1402 
     | 
    
         
            +
                elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP]:
         
     | 
| 
       1189 
1403 
     | 
    
         
             
                    set_send_opt()
         
     | 
| 
       1190 
1404 
     | 
    
         
             
                    set_recv_opt()
         
     | 
| 
       1191 
1405 
     | 
    
         
             
                else:
         
     | 
| 
       1192 
1406 
     | 
    
         
             
                    raise ValueError(f"Unsupported socket type: {socket_type}")
         
     | 
| 
       1193 
1407 
     | 
    
         | 
| 
       1194 
     | 
    
         
            -
                if bind:
         
     | 
| 
       1195 
     | 
    
         
            -
                    socket.bind(endpoint)
         
     | 
| 
       1196 
     | 
    
         
            -
                else:
         
     | 
| 
       1197 
     | 
    
         
            -
                    socket.connect(endpoint)
         
     | 
| 
       1198 
     | 
    
         
            -
             
     | 
| 
       1199 
     | 
    
         
            -
                return socket
         
     | 
| 
       1200 
     | 
    
         
            -
             
     | 
| 
       1201 
1408 
     | 
    
         | 
| 
       1202 
1409 
     | 
    
         
             
            def dump_to_file(dirpath, name, value):
         
     | 
| 
       1203 
1410 
     | 
    
         
             
                from sglang.srt.distributed import get_tensor_model_parallel_rank
         
     | 
| 
         @@ -1397,13 +1604,44 @@ def get_hpu_memory_capacity(): 
     | 
|
| 
       1397 
1604 
     | 
    
         | 
| 
       1398 
1605 
     | 
    
         
             
            def get_npu_memory_capacity():
         
     | 
| 
       1399 
1606 
     | 
    
         
             
                try:
         
     | 
| 
       1400 
     | 
    
         
            -
                    import torch_npu
         
     | 
| 
      
 1607 
     | 
    
         
            +
                    import torch_npu  # noqa: F401
         
     | 
| 
       1401 
1608 
     | 
    
         | 
| 
       1402 
1609 
     | 
    
         
             
                    return torch.npu.mem_get_info()[1] // 1024 // 1024  # unit: MB
         
     | 
| 
       1403 
1610 
     | 
    
         
             
                except ImportError as e:
         
     | 
| 
       1404 
1611 
     | 
    
         
             
                    raise ImportError("torch_npu is required when run on npu device.")
         
     | 
| 
       1405 
1612 
     | 
    
         | 
| 
       1406 
1613 
     | 
    
         | 
| 
      
 1614 
     | 
    
         
            +
            def get_cpu_memory_capacity():
         
     | 
| 
      
 1615 
     | 
    
         
            +
                # Per-rank memory capacity cannot be determined for customized core settings
         
     | 
| 
      
 1616 
     | 
    
         
            +
                if os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", ""):
         
     | 
| 
      
 1617 
     | 
    
         
            +
                    return None
         
     | 
| 
      
 1618 
     | 
    
         
            +
                n_numa_node: int = len(get_cpu_ids_by_node())
         
     | 
| 
      
 1619 
     | 
    
         
            +
                if n_numa_node == 0:
         
     | 
| 
      
 1620 
     | 
    
         
            +
                    # Cannot determine NUMA config, fallback to total memory and avoid ZeroDivisionError.
         
     | 
| 
      
 1621 
     | 
    
         
            +
                    return float(psutil.virtual_memory().total // (1 << 20))
         
     | 
| 
      
 1622 
     | 
    
         
            +
                try:
         
     | 
| 
      
 1623 
     | 
    
         
            +
                    numa_mem_list = list()
         
     | 
| 
      
 1624 
     | 
    
         
            +
                    file_prefix = "/sys/devices/system/node/"
         
     | 
| 
      
 1625 
     | 
    
         
            +
                    for numa_id in range(n_numa_node):
         
     | 
| 
      
 1626 
     | 
    
         
            +
                        file_meminfo = f"node{numa_id}/meminfo"
         
     | 
| 
      
 1627 
     | 
    
         
            +
                        with open(os.path.join(file_prefix, file_meminfo), "r") as f:
         
     | 
| 
      
 1628 
     | 
    
         
            +
                            # MemTotal info is at the 1st line
         
     | 
| 
      
 1629 
     | 
    
         
            +
                            line = f.readline()
         
     | 
| 
      
 1630 
     | 
    
         
            +
                            # Expected format: "Node 0 MemTotal:       100000000 kB"
         
     | 
| 
      
 1631 
     | 
    
         
            +
                            parts = line.split()
         
     | 
| 
      
 1632 
     | 
    
         
            +
                            if len(parts) >= 4 and parts[2] == "MemTotal:":
         
     | 
| 
      
 1633 
     | 
    
         
            +
                                numa_mem_list.append(int(parts[3]))
         
     | 
| 
      
 1634 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 1635 
     | 
    
         
            +
                                raise ValueError(f"Unexpected format in {file_meminfo}: {line}")
         
     | 
| 
      
 1636 
     | 
    
         
            +
                    # Retrieved value in KB, need MB
         
     | 
| 
      
 1637 
     | 
    
         
            +
                    numa_mem = float(min(numa_mem_list) // 1024)
         
     | 
| 
      
 1638 
     | 
    
         
            +
                    return numa_mem
         
     | 
| 
      
 1639 
     | 
    
         
            +
                except (FileNotFoundError, ValueError, IndexError):
         
     | 
| 
      
 1640 
     | 
    
         
            +
                    numa_mem = psutil.virtual_memory().total / n_numa_node
         
     | 
| 
      
 1641 
     | 
    
         
            +
                    # Retrieved value in Byte, need MB
         
     | 
| 
      
 1642 
     | 
    
         
            +
                    return float(numa_mem // (1 << 20))
         
     | 
| 
      
 1643 
     | 
    
         
            +
             
     | 
| 
      
 1644 
     | 
    
         
            +
             
     | 
| 
       1407 
1645 
     | 
    
         
             
            def get_device_memory_capacity(device: str = None):
         
     | 
| 
       1408 
1646 
     | 
    
         
             
                if is_cuda():
         
     | 
| 
       1409 
1647 
     | 
    
         
             
                    gpu_mem = get_nvgpu_memory_capacity()
         
     | 
| 
         @@ -1413,6 +1651,8 @@ def get_device_memory_capacity(device: str = None): 
     | 
|
| 
       1413 
1651 
     | 
    
         
             
                    gpu_mem = get_hpu_memory_capacity()
         
     | 
| 
       1414 
1652 
     | 
    
         
             
                elif device == "npu":
         
     | 
| 
       1415 
1653 
     | 
    
         
             
                    gpu_mem = get_npu_memory_capacity()
         
     | 
| 
      
 1654 
     | 
    
         
            +
                elif device == "cpu":
         
     | 
| 
      
 1655 
     | 
    
         
            +
                    gpu_mem = get_cpu_memory_capacity()
         
     | 
| 
       1416 
1656 
     | 
    
         
             
                else:
         
     | 
| 
       1417 
1657 
     | 
    
         
             
                    # GPU memory is not known yet or no GPU is available.
         
     | 
| 
       1418 
1658 
     | 
    
         
             
                    gpu_mem = None
         
     | 
| 
         @@ -1556,7 +1796,7 @@ def get_device(device_id: Optional[int] = None) -> str: 
     | 
|
| 
       1556 
1796 
     | 
    
         | 
| 
       1557 
1797 
     | 
    
         
             
                if is_habana_available():
         
     | 
| 
       1558 
1798 
     | 
    
         
             
                    try:
         
     | 
| 
       1559 
     | 
    
         
            -
                        import habana_frameworks.torch.hpu
         
     | 
| 
      
 1799 
     | 
    
         
            +
                        import habana_frameworks.torch.hpu  # noqa: F401
         
     | 
| 
       1560 
1800 
     | 
    
         | 
| 
       1561 
1801 
     | 
    
         
             
                        if torch.hpu.is_available():
         
     | 
| 
       1562 
1802 
     | 
    
         
             
                            if device_id == None:
         
     | 
| 
         @@ -1586,7 +1826,7 @@ def get_device_count() -> int: 
     | 
|
| 
       1586 
1826 
     | 
    
         | 
| 
       1587 
1827 
     | 
    
         
             
                if is_habana_available():
         
     | 
| 
       1588 
1828 
     | 
    
         
             
                    try:
         
     | 
| 
       1589 
     | 
    
         
            -
                        import habana_frameworks.torch.hpu
         
     | 
| 
      
 1829 
     | 
    
         
            +
                        import habana_frameworks.torch.hpu  # noqa: F401
         
     | 
| 
       1590 
1830 
     | 
    
         | 
| 
       1591 
1831 
     | 
    
         
             
                        if torch.hpu.is_available():
         
     | 
| 
       1592 
1832 
     | 
    
         
             
                            return torch.hpu.device_count()
         
     | 
| 
         @@ -1729,7 +1969,9 @@ def direct_register_custom_op( 
     | 
|
| 
       1729 
1969 
     | 
    
         
             
                    if fake_impl is not None:
         
     | 
| 
       1730 
1970 
     | 
    
         
             
                        my_lib._register_fake(op_name, fake_impl)
         
     | 
| 
       1731 
1971 
     | 
    
         
             
                except RuntimeError as error:
         
     | 
| 
       1732 
     | 
    
         
            -
                    if "Tried to register an operator" in str( 
     | 
| 
      
 1972 
     | 
    
         
            +
                    if "Tried to register an operator" in str(error) and "multiple times" in str(
         
     | 
| 
      
 1973 
     | 
    
         
            +
                        error
         
     | 
| 
      
 1974 
     | 
    
         
            +
                    ):
         
     | 
| 
       1733 
1975 
     | 
    
         
             
                        # Silently ignore duplicate registration errors
         
     | 
| 
       1734 
1976 
     | 
    
         
             
                        # This can happen in multi-engine scenarios
         
     | 
| 
       1735 
1977 
     | 
    
         
             
                        pass
         
     | 
| 
         @@ -1742,6 +1984,7 @@ def direct_register_custom_op( 
     | 
|
| 
       1742 
1984 
     | 
    
         | 
| 
       1743 
1985 
     | 
    
         | 
| 
       1744 
1986 
     | 
    
         
             
            def set_gpu_proc_affinity(
         
     | 
| 
      
 1987 
     | 
    
         
            +
                pp_size: int,
         
     | 
| 
       1745 
1988 
     | 
    
         
             
                tp_size: int,
         
     | 
| 
       1746 
1989 
     | 
    
         
             
                nnodes: int,
         
     | 
| 
       1747 
1990 
     | 
    
         
             
                gpu_id: int,
         
     | 
| 
         @@ -1750,7 +1993,8 @@ def set_gpu_proc_affinity( 
     | 
|
| 
       1750 
1993 
     | 
    
         
             
                pid = os.getpid()
         
     | 
| 
       1751 
1994 
     | 
    
         
             
                p = psutil.Process(pid)
         
     | 
| 
       1752 
1995 
     | 
    
         | 
| 
       1753 
     | 
    
         
            -
                 
     | 
| 
      
 1996 
     | 
    
         
            +
                nnodes_per_tp_group = max(nnodes // pp_size, 1)
         
     | 
| 
      
 1997 
     | 
    
         
            +
                tp_size_per_node = tp_size // nnodes_per_tp_group
         
     | 
| 
       1754 
1998 
     | 
    
         | 
| 
       1755 
1999 
     | 
    
         
             
                # total physical cores
         
     | 
| 
       1756 
2000 
     | 
    
         
             
                total_pcores = psutil.cpu_count(logical=False)
         
     | 
| 
         @@ -1862,7 +2106,7 @@ class MultiprocessingSerializer: 
     | 
|
| 
       1862 
2106 
     | 
    
         | 
| 
       1863 
2107 
     | 
    
         
             
                    if output_str:
         
     | 
| 
       1864 
2108 
     | 
    
         
             
                        # Convert bytes to base64-encoded string
         
     | 
| 
       1865 
     | 
    
         
            -
                         
     | 
| 
      
 2109 
     | 
    
         
            +
                        pybase64.b64encode(output).decode("utf-8")
         
     | 
| 
       1866 
2110 
     | 
    
         | 
| 
       1867 
2111 
     | 
    
         
             
                    return output
         
     | 
| 
       1868 
2112 
     | 
    
         | 
| 
         @@ -1951,50 +2195,6 @@ def set_uvicorn_logging_configs(): 
     | 
|
| 
       1951 
2195 
     | 
    
         
             
                LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
         
     | 
| 
       1952 
2196 
     | 
    
         | 
| 
       1953 
2197 
     | 
    
         | 
| 
       1954 
     | 
    
         
            -
            def get_ip() -> str:
         
     | 
| 
       1955 
     | 
    
         
            -
                # SGLANG_HOST_IP env can be ignore
         
     | 
| 
       1956 
     | 
    
         
            -
                host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
         
     | 
| 
       1957 
     | 
    
         
            -
                if host_ip:
         
     | 
| 
       1958 
     | 
    
         
            -
                    return host_ip
         
     | 
| 
       1959 
     | 
    
         
            -
             
     | 
| 
       1960 
     | 
    
         
            -
                # IP is not set, try to get it from the network interface
         
     | 
| 
       1961 
     | 
    
         
            -
             
     | 
| 
       1962 
     | 
    
         
            -
                # try ipv4
         
     | 
| 
       1963 
     | 
    
         
            -
                s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         
     | 
| 
       1964 
     | 
    
         
            -
                try:
         
     | 
| 
       1965 
     | 
    
         
            -
                    s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
         
     | 
| 
       1966 
     | 
    
         
            -
                    return s.getsockname()[0]
         
     | 
| 
       1967 
     | 
    
         
            -
                except Exception:
         
     | 
| 
       1968 
     | 
    
         
            -
                    pass
         
     | 
| 
       1969 
     | 
    
         
            -
             
     | 
| 
       1970 
     | 
    
         
            -
                # try ipv6
         
     | 
| 
       1971 
     | 
    
         
            -
                try:
         
     | 
| 
       1972 
     | 
    
         
            -
                    s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
         
     | 
| 
       1973 
     | 
    
         
            -
                    # Google's public DNS server, see
         
     | 
| 
       1974 
     | 
    
         
            -
                    # https://developers.google.com/speed/public-dns/docs/using#addresses
         
     | 
| 
       1975 
     | 
    
         
            -
                    s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         
     | 
| 
       1976 
     | 
    
         
            -
                    return s.getsockname()[0]
         
     | 
| 
       1977 
     | 
    
         
            -
                except Exception:
         
     | 
| 
       1978 
     | 
    
         
            -
                    pass
         
     | 
| 
       1979 
     | 
    
         
            -
             
     | 
| 
       1980 
     | 
    
         
            -
                # try  using hostname
         
     | 
| 
       1981 
     | 
    
         
            -
                hostname = socket.gethostname()
         
     | 
| 
       1982 
     | 
    
         
            -
                try:
         
     | 
| 
       1983 
     | 
    
         
            -
                    ip_addr = socket.gethostbyname(hostname)
         
     | 
| 
       1984 
     | 
    
         
            -
                    warnings.warn("using local ip address: {}".format(ip_addr))
         
     | 
| 
       1985 
     | 
    
         
            -
                    return ip_addr
         
     | 
| 
       1986 
     | 
    
         
            -
                except Exception:
         
     | 
| 
       1987 
     | 
    
         
            -
                    pass
         
     | 
| 
       1988 
     | 
    
         
            -
             
     | 
| 
       1989 
     | 
    
         
            -
                warnings.warn(
         
     | 
| 
       1990 
     | 
    
         
            -
                    "Failed to get the IP address, using 0.0.0.0 by default."
         
     | 
| 
       1991 
     | 
    
         
            -
                    "The value can be set by the environment variable"
         
     | 
| 
       1992 
     | 
    
         
            -
                    " SGLANG_HOST_IP or HOST_IP.",
         
     | 
| 
       1993 
     | 
    
         
            -
                    stacklevel=2,
         
     | 
| 
       1994 
     | 
    
         
            -
                )
         
     | 
| 
       1995 
     | 
    
         
            -
                return "0.0.0.0"
         
     | 
| 
       1996 
     | 
    
         
            -
             
     | 
| 
       1997 
     | 
    
         
            -
             
     | 
| 
       1998 
2198 
     | 
    
         
             
            def get_open_port() -> int:
         
     | 
| 
       1999 
2199 
     | 
    
         
             
                port = os.getenv("SGLANG_PORT")
         
     | 
| 
       2000 
2200 
     | 
    
         
             
                if port is not None:
         
     | 
| 
         @@ -2077,6 +2277,11 @@ def launch_dummy_health_check_server(host, port, enable_metrics): 
     | 
|
| 
       2077 
2277 
     | 
    
         | 
| 
       2078 
2278 
     | 
    
         
             
                app = FastAPI()
         
     | 
| 
       2079 
2279 
     | 
    
         | 
| 
      
 2280 
     | 
    
         
            +
                @app.get("/ping")
         
     | 
| 
      
 2281 
     | 
    
         
            +
                async def ping():
         
     | 
| 
      
 2282 
     | 
    
         
            +
                    """Could be used by the checkpoint-engine update script to confirm the server is up."""
         
     | 
| 
      
 2283 
     | 
    
         
            +
                    return Response(status_code=200)
         
     | 
| 
      
 2284 
     | 
    
         
            +
             
     | 
| 
       2080 
2285 
     | 
    
         
             
                @app.get("/health")
         
     | 
| 
       2081 
2286 
     | 
    
         
             
                async def health():
         
     | 
| 
       2082 
2287 
     | 
    
         
             
                    """Check the health of the http server."""
         
     | 
| 
         @@ -2199,6 +2404,8 @@ def retry( 
     | 
|
| 
       2199 
2404 
     | 
    
         
             
                    try:
         
     | 
| 
       2200 
2405 
     | 
    
         
             
                        return fn()
         
     | 
| 
       2201 
2406 
     | 
    
         
             
                    except Exception as e:
         
     | 
| 
      
 2407 
     | 
    
         
            +
                        traceback.print_exc()
         
     | 
| 
      
 2408 
     | 
    
         
            +
             
     | 
| 
       2202 
2409 
     | 
    
         
             
                        if try_index >= max_retry:
         
     | 
| 
       2203 
2410 
     | 
    
         
             
                            raise Exception(f"retry() exceed maximum number of retries.")
         
     | 
| 
       2204 
2411 
     | 
    
         | 
| 
         @@ -2212,11 +2419,30 @@ def retry( 
     | 
|
| 
       2212 
2419 
     | 
    
         
             
                        logger.warning(
         
     | 
| 
       2213 
2420 
     | 
    
         
             
                            f"retry() failed once ({try_index}th try, maximum {max_retry} retries). Will delay {delay:.2f}s and retry. Error: {e}"
         
     | 
| 
       2214 
2421 
     | 
    
         
             
                        )
         
     | 
| 
       2215 
     | 
    
         
            -
                        traceback.print_exc()
         
     | 
| 
       2216 
2422 
     | 
    
         | 
| 
       2217 
2423 
     | 
    
         
             
                        time.sleep(delay)
         
     | 
| 
       2218 
2424 
     | 
    
         | 
| 
       2219 
2425 
     | 
    
         | 
| 
      
 2426 
     | 
    
         
            +
            def has_hf_quant_config(model_path: str) -> bool:
         
     | 
| 
      
 2427 
     | 
    
         
            +
                """Check if the model path contains hf_quant_config.json file.
         
     | 
| 
      
 2428 
     | 
    
         
            +
             
     | 
| 
      
 2429 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 2430 
     | 
    
         
            +
                    model_path: Path to the model, can be local path or remote URL.
         
     | 
| 
      
 2431 
     | 
    
         
            +
             
     | 
| 
      
 2432 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 2433 
     | 
    
         
            +
                    True if hf_quant_config.json exists, False otherwise.
         
     | 
| 
      
 2434 
     | 
    
         
            +
                """
         
     | 
| 
      
 2435 
     | 
    
         
            +
                if os.path.exists(os.path.join(model_path, "hf_quant_config.json")):
         
     | 
| 
      
 2436 
     | 
    
         
            +
                    return True
         
     | 
| 
      
 2437 
     | 
    
         
            +
                try:
         
     | 
| 
      
 2438 
     | 
    
         
            +
                    from huggingface_hub import HfApi
         
     | 
| 
      
 2439 
     | 
    
         
            +
             
     | 
| 
      
 2440 
     | 
    
         
            +
                    hf_api = HfApi()
         
     | 
| 
      
 2441 
     | 
    
         
            +
                    return hf_api.file_exists(model_path, "hf_quant_config.json")
         
     | 
| 
      
 2442 
     | 
    
         
            +
                except Exception:
         
     | 
| 
      
 2443 
     | 
    
         
            +
                    return False
         
     | 
| 
      
 2444 
     | 
    
         
            +
             
     | 
| 
      
 2445 
     | 
    
         
            +
             
     | 
| 
       2220 
2446 
     | 
    
         
             
            def flatten_nested_list(nested_list):
         
     | 
| 
       2221 
2447 
     | 
    
         
             
                if isinstance(nested_list, list):
         
     | 
| 
       2222 
2448 
     | 
    
         
             
                    return [
         
     | 
| 
         @@ -2251,16 +2477,9 @@ def bind_or_assign(target, source): 
     | 
|
| 
       2251 
2477 
     | 
    
         
             
                    return source
         
     | 
| 
       2252 
2478 
     | 
    
         | 
| 
       2253 
2479 
     | 
    
         | 
| 
       2254 
     | 
    
         
            -
            def  
     | 
| 
       2255 
     | 
    
         
            -
                interface  
     | 
| 
       2256 
     | 
    
         
            -
             
     | 
| 
       2257 
     | 
    
         
            -
                    get_local_ip_by_nic(interface)
         
     | 
| 
       2258 
     | 
    
         
            -
                    if interface is not None
         
     | 
| 
       2259 
     | 
    
         
            -
                    else get_local_ip_by_remote()
         
     | 
| 
       2260 
     | 
    
         
            -
                )
         
     | 
| 
       2261 
     | 
    
         
            -
             
     | 
| 
       2262 
     | 
    
         
            -
             
     | 
| 
       2263 
     | 
    
         
            -
            def get_local_ip_by_nic(interface: str) -> str:
         
     | 
| 
      
 2480 
     | 
    
         
            +
            def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
         
     | 
| 
      
 2481 
     | 
    
         
            +
                if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
         
     | 
| 
      
 2482 
     | 
    
         
            +
                    return None
         
     | 
| 
       2264 
2483 
     | 
    
         
             
                try:
         
     | 
| 
       2265 
2484 
     | 
    
         
             
                    import netifaces
         
     | 
| 
       2266 
2485 
     | 
    
         
             
                except ImportError as e:
         
     | 
| 
         @@ -2281,15 +2500,13 @@ def get_local_ip_by_nic(interface: str) -> str: 
     | 
|
| 
       2281 
2500 
     | 
    
         
             
                            if ip and not ip.startswith("fe80::") and ip != "::1":
         
     | 
| 
       2282 
2501 
     | 
    
         
             
                                return ip.split("%")[0]
         
     | 
| 
       2283 
2502 
     | 
    
         
             
                except (ValueError, OSError) as e:
         
     | 
| 
       2284 
     | 
    
         
            -
                     
     | 
| 
       2285 
     | 
    
         
            -
                        "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
         
     | 
| 
      
 2503 
     | 
    
         
            +
                    logger.warning(
         
     | 
| 
      
 2504 
     | 
    
         
            +
                        f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
         
     | 
| 
       2286 
2505 
     | 
    
         
             
                    )
         
     | 
| 
       2287 
     | 
    
         
            -
             
     | 
| 
       2288 
     | 
    
         
            -
                # Fallback
         
     | 
| 
       2289 
     | 
    
         
            -
                return get_local_ip_by_remote()
         
     | 
| 
      
 2506 
     | 
    
         
            +
                return None
         
     | 
| 
       2290 
2507 
     | 
    
         | 
| 
       2291 
2508 
     | 
    
         | 
| 
       2292 
     | 
    
         
            -
            def get_local_ip_by_remote() -> str:
         
     | 
| 
      
 2509 
     | 
    
         
            +
            def get_local_ip_by_remote() -> Optional[str]:
         
     | 
| 
       2293 
2510 
     | 
    
         
             
                # try ipv4
         
     | 
| 
       2294 
2511 
     | 
    
         
             
                s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         
     | 
| 
       2295 
2512 
     | 
    
         
             
                try:
         
     | 
| 
         @@ -2314,7 +2531,51 @@ def get_local_ip_by_remote() -> str: 
     | 
|
| 
       2314 
2531 
     | 
    
         
             
                    s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         
     | 
| 
       2315 
2532 
     | 
    
         
             
                    return s.getsockname()[0]
         
     | 
| 
       2316 
2533 
     | 
    
         
             
                except Exception:
         
     | 
| 
       2317 
     | 
    
         
            -
                     
     | 
| 
      
 2534 
     | 
    
         
            +
                    logger.warning("Can not get local ip by remote")
         
     | 
| 
      
 2535 
     | 
    
         
            +
                return None
         
     | 
| 
      
 2536 
     | 
    
         
            +
             
     | 
| 
      
 2537 
     | 
    
         
            +
             
     | 
| 
      
 2538 
     | 
    
         
            +
            def get_local_ip_auto(fallback: str = None) -> str:
         
     | 
| 
      
 2539 
     | 
    
         
            +
                """
         
     | 
| 
      
 2540 
     | 
    
         
            +
                Automatically detect the local IP address using multiple fallback strategies.
         
     | 
| 
      
 2541 
     | 
    
         
            +
             
     | 
| 
      
 2542 
     | 
    
         
            +
                This function attempts to obtain the local IP address through several methods.
         
     | 
| 
      
 2543 
     | 
    
         
            +
                If all methods fail, it returns the specified fallback value or raises an exception.
         
     | 
| 
      
 2544 
     | 
    
         
            +
             
     | 
| 
      
 2545 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 2546 
     | 
    
         
            +
                    fallback (str, optional): Fallback IP address to return if all detection
         
     | 
| 
      
 2547 
     | 
    
         
            +
                        methods fail. For server applications, explicitly set this to
         
     | 
| 
      
 2548 
     | 
    
         
            +
                        "0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
         
     | 
| 
      
 2549 
     | 
    
         
            +
                        Defaults to None.
         
     | 
| 
      
 2550 
     | 
    
         
            +
             
     | 
| 
      
 2551 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 2552 
     | 
    
         
            +
                    str: The detected local IP address, or the fallback value if detection fails.
         
     | 
| 
      
 2553 
     | 
    
         
            +
             
     | 
| 
      
 2554 
     | 
    
         
            +
                Raises:
         
     | 
| 
      
 2555 
     | 
    
         
            +
                    ValueError: If IP detection fails and no fallback value is provided.
         
     | 
| 
      
 2556 
     | 
    
         
            +
             
     | 
| 
      
 2557 
     | 
    
         
            +
                Note:
         
     | 
| 
      
 2558 
     | 
    
         
            +
                    The function tries detection methods in the following order:
         
     | 
| 
      
 2559 
     | 
    
         
            +
                    1. Direct IP detection via get_ip()
         
     | 
| 
      
 2560 
     | 
    
         
            +
                    2. Network interface enumeration via get_local_ip_by_nic()
         
     | 
| 
      
 2561 
     | 
    
         
            +
                    3. Remote connection method via get_local_ip_by_remote()
         
     | 
| 
      
 2562 
     | 
    
         
            +
                """
         
     | 
| 
      
 2563 
     | 
    
         
            +
                # Try environment variable
         
     | 
| 
      
 2564 
     | 
    
         
            +
                host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
         
     | 
| 
      
 2565 
     | 
    
         
            +
                if host_ip:
         
     | 
| 
      
 2566 
     | 
    
         
            +
                    return host_ip
         
     | 
| 
      
 2567 
     | 
    
         
            +
                logger.debug("get_ip failed")
         
     | 
| 
      
 2568 
     | 
    
         
            +
                # Fallback
         
     | 
| 
      
 2569 
     | 
    
         
            +
                if ip := get_local_ip_by_nic():
         
     | 
| 
      
 2570 
     | 
    
         
            +
                    return ip
         
     | 
| 
      
 2571 
     | 
    
         
            +
                logger.debug("get_local_ip_by_nic failed")
         
     | 
| 
      
 2572 
     | 
    
         
            +
                # Fallback
         
     | 
| 
      
 2573 
     | 
    
         
            +
                if ip := get_local_ip_by_remote():
         
     | 
| 
      
 2574 
     | 
    
         
            +
                    return ip
         
     | 
| 
      
 2575 
     | 
    
         
            +
                logger.debug("get_local_ip_by_remote failed")
         
     | 
| 
      
 2576 
     | 
    
         
            +
                if fallback:
         
     | 
| 
      
 2577 
     | 
    
         
            +
                    return fallback
         
     | 
| 
      
 2578 
     | 
    
         
            +
                raise ValueError("Can not get local ip")
         
     | 
| 
       2318 
2579 
     | 
    
         | 
| 
       2319 
2580 
     | 
    
         | 
| 
       2320 
2581 
     | 
    
         
             
            def is_page_size_one(server_args):
         
     | 
| 
         @@ -2339,6 +2600,7 @@ def is_fa3_default_architecture(hf_config): 
     | 
|
| 
       2339 
2600 
     | 
    
         
             
                    "Qwen2ForCausalLM",
         
     | 
| 
       2340 
2601 
     | 
    
         
             
                    "Llama4ForConditionalGeneration",
         
     | 
| 
       2341 
2602 
     | 
    
         
             
                    "LlamaForCausalLM",
         
     | 
| 
      
 2603 
     | 
    
         
            +
                    "Olmo2ForCausalLM",
         
     | 
| 
       2342 
2604 
     | 
    
         
             
                    "Gemma2ForCausalLM",
         
     | 
| 
       2343 
2605 
     | 
    
         
             
                    "Gemma3ForConditionalGeneration",
         
     | 
| 
       2344 
2606 
     | 
    
         
             
                    "Qwen3ForCausalLM",
         
     | 
| 
         @@ -2366,15 +2628,15 @@ class BumpAllocator: 
     | 
|
| 
       2366 
2628 
     | 
    
         
             
            def log_info_on_rank0(logger, msg):
         
     | 
| 
       2367 
2629 
     | 
    
         
             
                from sglang.srt.distributed import get_tensor_model_parallel_rank
         
     | 
| 
       2368 
2630 
     | 
    
         | 
| 
       2369 
     | 
    
         
            -
                if get_tensor_model_parallel_rank() == 0:
         
     | 
| 
      
 2631 
     | 
    
         
            +
                if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
         
     | 
| 
       2370 
2632 
     | 
    
         
             
                    logger.info(msg)
         
     | 
| 
       2371 
2633 
     | 
    
         | 
| 
       2372 
2634 
     | 
    
         | 
| 
       2373 
2635 
     | 
    
         
             
            def load_json_config(data: str):
         
     | 
| 
       2374 
2636 
     | 
    
         
             
                try:
         
     | 
| 
       2375 
     | 
    
         
            -
                    return  
     | 
| 
      
 2637 
     | 
    
         
            +
                    return orjson.loads(data)
         
     | 
| 
       2376 
2638 
     | 
    
         
             
                except JSONDecodeError:
         
     | 
| 
       2377 
     | 
    
         
            -
                    return  
     | 
| 
      
 2639 
     | 
    
         
            +
                    return orjson.loads(Path(data).read_text())
         
     | 
| 
       2378 
2640 
     | 
    
         | 
| 
       2379 
2641 
     | 
    
         | 
| 
       2380 
2642 
     | 
    
         
             
            def dispose_tensor(x: torch.Tensor):
         
     | 
| 
         @@ -2496,14 +2758,6 @@ def read_system_prompt_from_file(model_name: str) -> str: 
     | 
|
| 
       2496 
2758 
     | 
    
         
             
                    return ""
         
     | 
| 
       2497 
2759 
     | 
    
         | 
| 
       2498 
2760 
     | 
    
         | 
| 
       2499 
     | 
    
         
            -
            def bind_or_assign(target, source):
         
     | 
| 
       2500 
     | 
    
         
            -
                if target is not None:
         
     | 
| 
       2501 
     | 
    
         
            -
                    target.copy_(source)
         
     | 
| 
       2502 
     | 
    
         
            -
                    return target
         
     | 
| 
       2503 
     | 
    
         
            -
                else:
         
     | 
| 
       2504 
     | 
    
         
            -
                    return source
         
     | 
| 
       2505 
     | 
    
         
            -
             
     | 
| 
       2506 
     | 
    
         
            -
             
     | 
| 
       2507 
2761 
     | 
    
         
             
            def prepack_weight_if_needed(weight):
         
     | 
| 
       2508 
2762 
     | 
    
         
             
                if weight.device != torch.device("cpu"):
         
     | 
| 
       2509 
2763 
     | 
    
         
             
                    return weight
         
     | 
| 
         @@ -2749,7 +3003,7 @@ def get_cpu_ids_by_node(): 
     | 
|
| 
       2749 
3003 
     | 
    
         
             
            def is_shm_available(dtype, world_size, local_size):
         
     | 
| 
       2750 
3004 
     | 
    
         
             
                return (
         
     | 
| 
       2751 
3005 
     | 
    
         
             
                    cpu_has_amx_support()
         
     | 
| 
       2752 
     | 
    
         
            -
                    and dtype in [torch.bfloat16, torch.float]
         
     | 
| 
      
 3006 
     | 
    
         
            +
                    and dtype in [torch.bfloat16, torch.float16, torch.float]
         
     | 
| 
       2753 
3007 
     | 
    
         
             
                    and world_size >= 1
         
     | 
| 
       2754 
3008 
     | 
    
         
             
                    and world_size == local_size
         
     | 
| 
       2755 
3009 
     | 
    
         
             
                )
         
     | 
| 
         @@ -2800,10 +3054,6 @@ def lru_cache_frozenset(maxsize=128): 
     | 
|
| 
       2800 
3054 
     | 
    
         
             
                return decorator
         
     | 
| 
       2801 
3055 
     | 
    
         | 
| 
       2802 
3056 
     | 
    
         | 
| 
       2803 
     | 
    
         
            -
            def get_origin_rid(rid):
         
     | 
| 
       2804 
     | 
    
         
            -
                return rid.split("_", 1)[1] if "_" in rid else rid
         
     | 
| 
       2805 
     | 
    
         
            -
             
     | 
| 
       2806 
     | 
    
         
            -
             
     | 
| 
       2807 
3057 
     | 
    
         
             
            def apply_module_patch(target_module, target_function, wrappers):
         
     | 
| 
       2808 
3058 
     | 
    
         
             
                original_module, original_function = parse_module_path(
         
     | 
| 
       2809 
3059 
     | 
    
         
             
                    target_module, target_function, False
         
     | 
| 
         @@ -3042,6 +3292,44 @@ def check_cuda_result(raw_output): 
     | 
|
| 
       3042 
3292 
     | 
    
         
             
                return results
         
     | 
| 
       3043 
3293 
     | 
    
         | 
| 
       3044 
3294 
     | 
    
         | 
| 
      
 3295 
     | 
    
         
            +
            def get_physical_device_id(pytorch_device_id: int) -> int:
         
     | 
| 
      
 3296 
     | 
    
         
            +
                """
         
     | 
| 
      
 3297 
     | 
    
         
            +
                Convert PyTorch logical device ID to physical device ID.
         
     | 
| 
      
 3298 
     | 
    
         
            +
                """
         
     | 
| 
      
 3299 
     | 
    
         
            +
                cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
         
     | 
| 
      
 3300 
     | 
    
         
            +
                assert (
         
     | 
| 
      
 3301 
     | 
    
         
            +
                    cuda_visible_devices is not None
         
     | 
| 
      
 3302 
     | 
    
         
            +
                ), "CUDA_VISIBLE_DEVICES should be set in a scheduler"
         
     | 
| 
      
 3303 
     | 
    
         
            +
                device_list = cuda_visible_devices.split(",")
         
     | 
| 
      
 3304 
     | 
    
         
            +
                assert (
         
     | 
| 
      
 3305 
     | 
    
         
            +
                    len(device_list) == 1
         
     | 
| 
      
 3306 
     | 
    
         
            +
                ), "CUDA_VISIBLE_DEVICES should be set to a single device in a scheduler"
         
     | 
| 
      
 3307 
     | 
    
         
            +
                return int(device_list[0])
         
     | 
| 
      
 3308 
     | 
    
         
            +
             
     | 
| 
      
 3309 
     | 
    
         
            +
             
     | 
| 
      
 3310 
     | 
    
         
            +
            def get_device_sm_nvidia_smi():
         
     | 
| 
      
 3311 
     | 
    
         
            +
                try:
         
     | 
| 
      
 3312 
     | 
    
         
            +
                    # Run nvidia-smi command and capture output
         
     | 
| 
      
 3313 
     | 
    
         
            +
                    result = subprocess.run(
         
     | 
| 
      
 3314 
     | 
    
         
            +
                        ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
         
     | 
| 
      
 3315 
     | 
    
         
            +
                        capture_output=True,
         
     | 
| 
      
 3316 
     | 
    
         
            +
                        text=True,
         
     | 
| 
      
 3317 
     | 
    
         
            +
                        check=True,
         
     | 
| 
      
 3318 
     | 
    
         
            +
                    )
         
     | 
| 
      
 3319 
     | 
    
         
            +
             
     | 
| 
      
 3320 
     | 
    
         
            +
                    # Get the first line of output (assuming at least one GPU exists)
         
     | 
| 
      
 3321 
     | 
    
         
            +
                    compute_cap_str = result.stdout.strip().split("\n")[0]
         
     | 
| 
      
 3322 
     | 
    
         
            +
             
     | 
| 
      
 3323 
     | 
    
         
            +
                    # Convert string (e.g., "9.0") to tuple of integers (9, 0)
         
     | 
| 
      
 3324 
     | 
    
         
            +
                    major, minor = map(int, compute_cap_str.split("."))
         
     | 
| 
      
 3325 
     | 
    
         
            +
                    return (major, minor)
         
     | 
| 
      
 3326 
     | 
    
         
            +
             
     | 
| 
      
 3327 
     | 
    
         
            +
                except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
         
     | 
| 
      
 3328 
     | 
    
         
            +
                    # Handle cases where nvidia-smi isn't available or output is unexpected
         
     | 
| 
      
 3329 
     | 
    
         
            +
                    print(f"Error getting compute capability: {e}")
         
     | 
| 
      
 3330 
     | 
    
         
            +
                    return (0, 0)  # Default/fallback value
         
     | 
| 
      
 3331 
     | 
    
         
            +
             
     | 
| 
      
 3332 
     | 
    
         
            +
             
     | 
| 
       3045 
3333 
     | 
    
         
             
            def numa_bind_to_node(node: int):
         
     | 
| 
       3046 
3334 
     | 
    
         
             
                libnuma = ctypes.CDLL("libnuma.so")
         
     | 
| 
       3047 
3335 
     | 
    
         
             
                if libnuma.numa_available() < 0:
         
     | 
| 
         @@ -3053,8 +3341,190 @@ def numa_bind_to_node(node: int): 
     | 
|
| 
       3053 
3341 
     | 
    
         | 
| 
       3054 
3342 
     | 
    
         
             
            def json_list_type(value):
         
     | 
| 
       3055 
3343 
     | 
    
         
             
                try:
         
     | 
| 
       3056 
     | 
    
         
            -
                    return  
     | 
| 
      
 3344 
     | 
    
         
            +
                    return orjson.loads(value)
         
     | 
| 
       3057 
3345 
     | 
    
         
             
                except json.JSONDecodeError:
         
     | 
| 
       3058 
3346 
     | 
    
         
             
                    raise argparse.ArgumentTypeError(
         
     | 
| 
       3059 
3347 
     | 
    
         
             
                        f"Invalid JSON list: {value}. Please provide a valid JSON list."
         
     | 
| 
       3060 
3348 
     | 
    
         
             
                    )
         
     | 
| 
      
 3349 
     | 
    
         
            +
             
     | 
| 
      
 3350 
     | 
    
         
            +
             
     | 
| 
      
 3351 
     | 
    
         
            +
            @contextmanager
         
     | 
| 
      
 3352 
     | 
    
         
            +
            def maybe_reindex_device_id(gpu_id: int):
         
     | 
| 
      
 3353 
     | 
    
         
            +
             
     | 
| 
      
 3354 
     | 
    
         
            +
                if envs.SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS.get() is False or not is_cuda_alike():
         
     | 
| 
      
 3355 
     | 
    
         
            +
                    yield gpu_id
         
     | 
| 
      
 3356 
     | 
    
         
            +
                    return
         
     | 
| 
      
 3357 
     | 
    
         
            +
             
     | 
| 
      
 3358 
     | 
    
         
            +
                original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
         
     | 
| 
      
 3359 
     | 
    
         
            +
                if original_cuda_visible_devices:
         
     | 
| 
      
 3360 
     | 
    
         
            +
                    cuda_visible_devices = original_cuda_visible_devices.split(",")
         
     | 
| 
      
 3361 
     | 
    
         
            +
                else:
         
     | 
| 
      
 3362 
     | 
    
         
            +
                    cuda_visible_devices = []
         
     | 
| 
      
 3363 
     | 
    
         
            +
             
     | 
| 
      
 3364 
     | 
    
         
            +
                str_gpu_id = cuda_visible_devices[gpu_id] if cuda_visible_devices else str(gpu_id)
         
     | 
| 
      
 3365 
     | 
    
         
            +
                os.environ["CUDA_VISIBLE_DEVICES"] = str_gpu_id
         
     | 
| 
      
 3366 
     | 
    
         
            +
             
     | 
| 
      
 3367 
     | 
    
         
            +
                logger.debug(f"Set CUDA_VISIBLE_DEVICES to {str_gpu_id}")
         
     | 
| 
      
 3368 
     | 
    
         
            +
             
     | 
| 
      
 3369 
     | 
    
         
            +
                yield 0
         
     | 
| 
      
 3370 
     | 
    
         
            +
             
     | 
| 
      
 3371 
     | 
    
         
            +
                if original_cuda_visible_devices:
         
     | 
| 
      
 3372 
     | 
    
         
            +
                    os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
         
     | 
| 
      
 3373 
     | 
    
         
            +
                else:
         
     | 
| 
      
 3374 
     | 
    
         
            +
                    del os.environ["CUDA_VISIBLE_DEVICES"]
         
     | 
| 
      
 3375 
     | 
    
         
            +
             
     | 
| 
      
 3376 
     | 
    
         
            +
             
     | 
| 
      
 3377 
     | 
    
         
            +
            def get_extend_input_len_swa_limit(
         
     | 
| 
      
 3378 
     | 
    
         
            +
                sliding_window_size: int, chunked_prefill_size: int, page_size: int
         
     | 
| 
      
 3379 
     | 
    
         
            +
            ) -> int:
         
     | 
| 
      
 3380 
     | 
    
         
            +
                # 1. a factor of 2x is because each prefill contains chunked_prefill_size tokens,
         
     | 
| 
      
 3381 
     | 
    
         
            +
                #    and between prefills, we run swa_radix_cache.cache_unfinished_req(),
         
     | 
| 
      
 3382 
     | 
    
         
            +
                #    so we unlock the previously locked nodes.
         
     | 
| 
      
 3383 
     | 
    
         
            +
                # 2. max is to handle the case that chunked_prefill_size is larger than sliding_window_size.
         
     | 
| 
      
 3384 
     | 
    
         
            +
                #    in that case, each prefill contains chunked_prefill_size tokens,
         
     | 
| 
      
 3385 
     | 
    
         
            +
                #    and we can only free out-of-sliding-window kv indices after each prefill.
         
     | 
| 
      
 3386 
     | 
    
         
            +
                # 3. page_size is because we want to have 1 token extra for generated tokens.
         
     | 
| 
      
 3387 
     | 
    
         
            +
                return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
         
     | 
| 
      
 3388 
     | 
    
         
            +
             
     | 
| 
      
 3389 
     | 
    
         
            +
             
     | 
| 
      
 3390 
     | 
    
         
            +
            def get_num_new_pages(
         
     | 
| 
      
 3391 
     | 
    
         
            +
                seq_lens: torch.Tensor,
         
     | 
| 
      
 3392 
     | 
    
         
            +
                page_size: int,
         
     | 
| 
      
 3393 
     | 
    
         
            +
                prefix_lens: Optional[torch.Tensor] = None,
         
     | 
| 
      
 3394 
     | 
    
         
            +
                decode: bool = False,
         
     | 
| 
      
 3395 
     | 
    
         
            +
            ) -> torch.Tensor:
         
     | 
| 
      
 3396 
     | 
    
         
            +
                """
         
     | 
| 
      
 3397 
     | 
    
         
            +
                Get the number of new pages for the given prefix and sequence lengths.
         
     | 
| 
      
 3398 
     | 
    
         
            +
                We use cpu tensors to avoid blocking kernel launch.
         
     | 
| 
      
 3399 
     | 
    
         
            +
                """
         
     | 
| 
      
 3400 
     | 
    
         
            +
                cpu_device = torch.device("cpu")
         
     | 
| 
      
 3401 
     | 
    
         
            +
                assert seq_lens.device == cpu_device
         
     | 
| 
      
 3402 
     | 
    
         
            +
             
     | 
| 
      
 3403 
     | 
    
         
            +
                if prefix_lens is None or decode:
         
     | 
| 
      
 3404 
     | 
    
         
            +
                    # NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
         
     | 
| 
      
 3405 
     | 
    
         
            +
                    assert decode
         
     | 
| 
      
 3406 
     | 
    
         
            +
                    return (seq_lens % page_size == 1).int().sum().item()
         
     | 
| 
      
 3407 
     | 
    
         
            +
             
     | 
| 
      
 3408 
     | 
    
         
            +
                assert prefix_lens.device == cpu_device
         
     | 
| 
      
 3409 
     | 
    
         
            +
                num_pages_after = (seq_lens + page_size - 1) // page_size
         
     | 
| 
      
 3410 
     | 
    
         
            +
                num_pages_before = (prefix_lens + page_size - 1) // page_size
         
     | 
| 
      
 3411 
     | 
    
         
            +
                num_new_pages = num_pages_after - num_pages_before
         
     | 
| 
      
 3412 
     | 
    
         
            +
                sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
         
     | 
| 
      
 3413 
     | 
    
         
            +
                return sum_num_new_pages.item()
         
     | 
| 
      
 3414 
     | 
    
         
            +
             
     | 
| 
      
 3415 
     | 
    
         
            +
             
     | 
| 
      
 3416 
     | 
    
         
            +
            class CachedKernel:
         
     | 
| 
      
 3417 
     | 
    
         
            +
                """
         
     | 
| 
      
 3418 
     | 
    
         
            +
                Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
         
     | 
| 
      
 3419 
     | 
    
         
            +
             
     | 
| 
      
 3420 
     | 
    
         
            +
                This wrapper caches compiled Triton kernels based on keys extracted by a
         
     | 
| 
      
 3421 
     | 
    
         
            +
                user-provided key function to avoid redundant compilations.
         
     | 
| 
      
 3422 
     | 
    
         
            +
                """
         
     | 
| 
      
 3423 
     | 
    
         
            +
             
     | 
| 
      
 3424 
     | 
    
         
            +
                def __init__(self, fn, key_fn=None):
         
     | 
| 
      
 3425 
     | 
    
         
            +
                    self.fn = fn
         
     | 
| 
      
 3426 
     | 
    
         
            +
                    assert isinstance(fn, triton.runtime.jit.JITFunction)
         
     | 
| 
      
 3427 
     | 
    
         
            +
             
     | 
| 
      
 3428 
     | 
    
         
            +
                    original_fn = fn.fn
         
     | 
| 
      
 3429 
     | 
    
         
            +
                    self.signature = inspect.signature(original_fn)
         
     | 
| 
      
 3430 
     | 
    
         
            +
                    self.param_names = tuple(self.signature.parameters.keys())
         
     | 
| 
      
 3431 
     | 
    
         
            +
                    self.num_args = len(self.param_names)
         
     | 
| 
      
 3432 
     | 
    
         
            +
             
     | 
| 
      
 3433 
     | 
    
         
            +
                    # Check that no parameters have default values
         
     | 
| 
      
 3434 
     | 
    
         
            +
                    for name, param in self.signature.parameters.items():
         
     | 
| 
      
 3435 
     | 
    
         
            +
                        assert (
         
     | 
| 
      
 3436 
     | 
    
         
            +
                            param.default is inspect.Parameter.empty
         
     | 
| 
      
 3437 
     | 
    
         
            +
                        ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
         
     | 
| 
      
 3438 
     | 
    
         
            +
             
     | 
| 
      
 3439 
     | 
    
         
            +
                    functools.update_wrapper(self, original_fn)
         
     | 
| 
      
 3440 
     | 
    
         
            +
                    self.kernel_cache = {}
         
     | 
| 
      
 3441 
     | 
    
         
            +
             
     | 
| 
      
 3442 
     | 
    
         
            +
                    # Store the key function
         
     | 
| 
      
 3443 
     | 
    
         
            +
                    self.key_fn = key_fn
         
     | 
| 
      
 3444 
     | 
    
         
            +
             
     | 
| 
      
 3445 
     | 
    
         
            +
                def __getitem__(self, grid):
         
     | 
| 
      
 3446 
     | 
    
         
            +
                    """
         
     | 
| 
      
 3447 
     | 
    
         
            +
                    Index with grid to get a launcher function.
         
     | 
| 
      
 3448 
     | 
    
         
            +
                    Returns a launcher that will handle caching based on the key function.
         
     | 
| 
      
 3449 
     | 
    
         
            +
                    """
         
     | 
| 
      
 3450 
     | 
    
         
            +
                    assert (
         
     | 
| 
      
 3451 
     | 
    
         
            +
                        isinstance(grid, tuple) and len(grid) <= 3
         
     | 
| 
      
 3452 
     | 
    
         
            +
                    ), "Grid must be a tuple with at most 3 dimensions."
         
     | 
| 
      
 3453 
     | 
    
         
            +
             
     | 
| 
      
 3454 
     | 
    
         
            +
                    # Normalize grid once
         
     | 
| 
      
 3455 
     | 
    
         
            +
                    if len(grid) < 3:
         
     | 
| 
      
 3456 
     | 
    
         
            +
                        grid = grid + (1,) * (3 - len(grid))
         
     | 
| 
      
 3457 
     | 
    
         
            +
             
     | 
| 
      
 3458 
     | 
    
         
            +
                    def launcher(*args, **kwargs):
         
     | 
| 
      
 3459 
     | 
    
         
            +
                        cache_key = self.key_fn(args, kwargs)
         
     | 
| 
      
 3460 
     | 
    
         
            +
             
     | 
| 
      
 3461 
     | 
    
         
            +
                        cached_kernel = self.kernel_cache.get(cache_key)
         
     | 
| 
      
 3462 
     | 
    
         
            +
             
     | 
| 
      
 3463 
     | 
    
         
            +
                        if cached_kernel is None:
         
     | 
| 
      
 3464 
     | 
    
         
            +
                            # First time: compile and cache the kernel
         
     | 
| 
      
 3465 
     | 
    
         
            +
                            cached_kernel = self.fn[grid](*args, **kwargs)
         
     | 
| 
      
 3466 
     | 
    
         
            +
                            self.kernel_cache[cache_key] = cached_kernel
         
     | 
| 
      
 3467 
     | 
    
         
            +
                            return cached_kernel
         
     | 
| 
      
 3468 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 3469 
     | 
    
         
            +
                            # Use cached kernel
         
     | 
| 
      
 3470 
     | 
    
         
            +
                            all_args = self._build_args(args, kwargs)
         
     | 
| 
      
 3471 
     | 
    
         
            +
                            cached_kernel[grid](*all_args)
         
     | 
| 
      
 3472 
     | 
    
         
            +
                            return cached_kernel
         
     | 
| 
      
 3473 
     | 
    
         
            +
             
     | 
| 
      
 3474 
     | 
    
         
            +
                    return launcher
         
     | 
| 
      
 3475 
     | 
    
         
            +
             
     | 
| 
      
 3476 
     | 
    
         
            +
                def _build_args(self, args, kwargs):
         
     | 
| 
      
 3477 
     | 
    
         
            +
                    """
         
     | 
| 
      
 3478 
     | 
    
         
            +
                    Build the complete argument list for kernel invocation.
         
     | 
| 
      
 3479 
     | 
    
         
            +
                    """
         
     | 
| 
      
 3480 
     | 
    
         
            +
                    complete_args = list(args)
         
     | 
| 
      
 3481 
     | 
    
         
            +
             
     | 
| 
      
 3482 
     | 
    
         
            +
                    for i in range(len(args), self.num_args):
         
     | 
| 
      
 3483 
     | 
    
         
            +
                        name = self.param_names[i]
         
     | 
| 
      
 3484 
     | 
    
         
            +
                        value = kwargs.get(name, inspect.Parameter.empty)
         
     | 
| 
      
 3485 
     | 
    
         
            +
                        if value is not inspect.Parameter.empty:
         
     | 
| 
      
 3486 
     | 
    
         
            +
                            complete_args.append(value)
         
     | 
| 
      
 3487 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 3488 
     | 
    
         
            +
                            raise ValueError(f"Missing argument: {name}")
         
     | 
| 
      
 3489 
     | 
    
         
            +
             
     | 
| 
      
 3490 
     | 
    
         
            +
                    return complete_args
         
     | 
| 
      
 3491 
     | 
    
         
            +
             
     | 
| 
      
 3492 
     | 
    
         
            +
                def _clear_cache(self):
         
     | 
| 
      
 3493 
     | 
    
         
            +
                    """
         
     | 
| 
      
 3494 
     | 
    
         
            +
                    Clear the kernel cache for testing purposes.
         
     | 
| 
      
 3495 
     | 
    
         
            +
                    """
         
     | 
| 
      
 3496 
     | 
    
         
            +
                    self.kernel_cache.clear()
         
     | 
| 
      
 3497 
     | 
    
         
            +
             
     | 
| 
      
 3498 
     | 
    
         
            +
             
     | 
| 
      
 3499 
     | 
    
         
            +
            def cached_triton_kernel(key_fn=None):
         
     | 
| 
      
 3500 
     | 
    
         
            +
                """
         
     | 
| 
      
 3501 
     | 
    
         
            +
                Decorator that enables key-based caching for Triton kernels using a key function.
         
     | 
| 
      
 3502 
     | 
    
         
            +
             
     | 
| 
      
 3503 
     | 
    
         
            +
                It essentially bypasses Triton's built-in caching mechanism, allowing users to
         
     | 
| 
      
 3504 
     | 
    
         
            +
                define their own caching strategy based on kernel parameters. This helps reduce
         
     | 
| 
      
 3505 
     | 
    
         
            +
                the heavy overheads of Triton kernel launch when the kernel specialization dispatch
         
     | 
| 
      
 3506 
     | 
    
         
            +
                is simple.
         
     | 
| 
      
 3507 
     | 
    
         
            +
             
     | 
| 
      
 3508 
     | 
    
         
            +
                Usage:
         
     | 
| 
      
 3509 
     | 
    
         
            +
                    @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
         
     | 
| 
      
 3510 
     | 
    
         
            +
                    @triton.jit
         
     | 
| 
      
 3511 
     | 
    
         
            +
                    def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
         
     | 
| 
      
 3512 
     | 
    
         
            +
                        ...
         
     | 
| 
      
 3513 
     | 
    
         
            +
             
     | 
| 
      
 3514 
     | 
    
         
            +
                    # Invoke normally
         
     | 
| 
      
 3515 
     | 
    
         
            +
                    my_kernel[grid](x, y, BLOCK_SIZE=1024)
         
     | 
| 
      
 3516 
     | 
    
         
            +
             
     | 
| 
      
 3517 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 3518 
     | 
    
         
            +
                    key_fn: A function that takes (args, kwargs) and returns the cache key(s).
         
     | 
| 
      
 3519 
     | 
    
         
            +
                            The key can be a single value or a tuple of values.
         
     | 
| 
      
 3520 
     | 
    
         
            +
             
     | 
| 
      
 3521 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 3522 
     | 
    
         
            +
                    A decorator that wraps the kernel with caching functionality.
         
     | 
| 
      
 3523 
     | 
    
         
            +
             
     | 
| 
      
 3524 
     | 
    
         
            +
                Note: Kernels with default parameter values are not supported and will raise an assertion error.
         
     | 
| 
      
 3525 
     | 
    
         
            +
                """
         
     | 
| 
      
 3526 
     | 
    
         
            +
             
     | 
| 
      
 3527 
     | 
    
         
            +
                def decorator(fn):
         
     | 
| 
      
 3528 
     | 
    
         
            +
                    return CachedKernel(fn, key_fn)
         
     | 
| 
      
 3529 
     | 
    
         
            +
             
     | 
| 
      
 3530 
     | 
    
         
            +
                return decorator
         
     |