sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/{utils.py → utils/common.py} (+451 -77).

```diff
@@ -15,12 +15,14 @@
 
 from __future__ import annotations
 
+import argparse
 import asyncio
 import builtins
 import ctypes
 import dataclasses
 import functools
 import importlib
+import inspect
 import io
 import ipaddress
 import itertools
@@ -81,11 +83,9 @@ from packaging import version as pkg_version
 from PIL import Image
 from starlette.routing import Mount
 from torch import nn
-from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import FileCacheManager
 from typing_extensions import Literal
 
 from sglang.srt.metrics.func_timer import enable_func_timer
@@ -166,6 +166,7 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)
 
 
+@lru_cache(maxsize=1)
 def is_blackwell():
     if not is_cuda():
         return False
@@ -174,6 +175,8 @@ def is_blackwell():
 
 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
@@ -181,6 +184,8 @@ def is_sm100_supported(device=None) -> bool:
 
 @lru_cache(maxsize=1)
 def is_sm90_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 9) and (
         torch.version.cuda >= "12.3"
     )
@@ -190,6 +195,7 @@ _warned_bool_env_var_keys = set()
 
 
 def get_bool_env_var(name: str, default: str = "false") -> bool:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name, default)
     value = value.lower()
 
@@ -207,6 +213,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
 
 
 def get_int_env_var(name: str, default: int = 0) -> int:
+    # FIXME: move your environment variable to sglang.srt.environ
    value = os.getenv(name)
    if value is None or not value.strip():
        return default
@@ -230,8 +237,16 @@ except:
     is_intel_amx_backend_available = False
 
 
+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+
+
 def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+    return is_amx_tile_supported and is_intel_amx_backend_available
 
 
 def use_intel_amx_backend(layer):
```
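The new module-level `try`/`except` hoists the `torch._C._cpu._is_amx_tile_supported()` probe out of `cpu_has_amx_support()` so that, per the in-diff comment, `torch.compile` sees a plain module-level boolean instead of tracing a call into a private C binding. A minimal sketch of that pattern, assuming only that the probe can raise on builds lacking the binding (`cpu_supports_amx` is a hypothetical stand-in name):

```python
# Sketch of the probe-at-import-time pattern: run the fragile probe once at
# module load, and let the public helper read a constant that torch.compile
# can fold away.
import torch

try:
    _IS_AMX_TILE_SUPPORTED = torch._C._cpu._is_amx_tile_supported()
except Exception:  # the private binding may be missing on older builds
    _IS_AMX_TILE_SUPPORTED = False


def cpu_supports_amx() -> bool:  # hypothetical stand-in for cpu_has_amx_support
    return _IS_AMX_TILE_SUPPORTED
```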
```diff
@@ -426,7 +441,9 @@ def get_available_gpu_memory(
 
     elif device == "cpu":
         # TODO: rename the variables in the current function to be not GPU specific
-        free_gpu_memory = psutil.virtual_memory().available
+        total_free_memory = psutil.virtual_memory().available
+        n_numa_node: int = len(get_cpu_ids_by_node())
+        free_gpu_memory = round(total_free_memory / n_numa_node, 3)
     elif device == "npu":
         num_gpus = torch.npu.device_count()
         assert gpu_id < num_gpus
@@ -454,7 +471,7 @@ def is_pin_memory_available() -> bool:
 
 class LayerFn(Protocol):
 
-    def __call__(self,
+    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...
 
 
 def make_layers(
@@ -465,7 +482,7 @@ def make_layers(
     prefix: str = "",
     return_tuple: bool = False,
     offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[
+) -> Tuple[torch.nn.Module, int, int]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices
@@ -501,6 +518,68 @@
     return modules, start_layer, end_layer
 
 
+def make_layers_non_pp(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> torch.nn.ModuleList:
+    from sglang.srt.offloader import get_offloader
+
+    layers = torch.nn.ModuleList(
+        get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(num_hidden_layers)
+            )
+        )
+    )
+    return layers
+
+
+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation(CMO).
+    Launch a new stream to prefetch the weight of matmul when running other
+    AIV or communication kernels, aiming to overlap the memory access time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1GB, a large value to prefetch entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
```
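The CMO helpers added here follow the usual side-stream overlap recipe: `prepare_weight_cache` makes a dedicated stream wait on the current stream, issues `torch_npu.npu_prefetch` calls on it, and `wait_cmo_stream` joins back before the prefetched weights are consumed. A hedged sketch of the same pattern on the generic CUDA stream API, since `torch_npu.npu_prefetch` is Ascend-only; assumes a CUDA device is present:

```python
# Side-stream overlap sketch mirroring get_cmo_stream / prepare_weight_cache /
# wait_cmo_stream, using plain CUDA streams and an async copy instead of the
# Ascend-specific torch_npu.npu_prefetch.
import torch

side_stream = torch.cuda.Stream()


def start_prefetch(dst: torch.Tensor, src: torch.Tensor) -> None:
    # Order the side stream after work already queued on the current stream.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        dst.copy_(src, non_blocking=True)  # overlaps with compute kernels


def finish_prefetch() -> None:
    # Mirror of wait_cmo_stream(): block the current stream on the side
    # stream before anything reads the prefetched data.
    torch.cuda.current_stream().wait_stream(side_stream)
```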
```diff
@@ -738,6 +817,25 @@ def load_image(
     return image, image_size
 
 
+def get_image_bytes(image_file: Union[str, bytes]):
+    if isinstance(image_file, bytes):
+        return image_file
+    elif image_file.startswith("http://") or image_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+        response = requests.get(image_file, timeout=timeout)
+        return response.content
+    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        with open(image_file, "rb") as f:
+            return f.read()
+    elif image_file.startswith("data:"):
+        image_file = image_file.split(",")[1]
+        return pybase64.b64decode(image_file)
+    elif isinstance(image_file, str):
+        return pybase64.b64decode(image_file)
+    else:
+        raise NotImplementedError(f"Invalid image: {image_file}")
+
+
 def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
     from decord import VideoReader, cpu, gpu
```
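`get_image_bytes` dispatches purely on the shape of its argument: raw `bytes` pass through, `http(s)://` URLs are fetched with a `REQUEST_TIMEOUT`-controlled timeout, known file extensions are read from disk, `data:` URIs are split at the comma, and any remaining string is treated as bare base64 (which is why the `data:` branch must run before the bare-base64 fallback). A usage sketch; the import path assumes the renamed module, and all inputs are placeholders:

```python
# Usage sketch for get_image_bytes; module path follows the rename to
# sglang/srt/utils/common.py, and all inputs below are placeholders.
from sglang.srt.utils.common import get_image_bytes

passthrough = get_image_bytes(b"\x89PNG\r\n")                     # bytes: as-is
fetched = get_image_bytes("https://example.com/cat.png")          # URL: requests.get
on_disk = get_image_bytes("/tmp/cat.jpg")                         # extension: file read
data_uri = get_image_bytes("data:image/png;base64,iVBORw0KGgo=")  # split at comma
bare_b64 = get_image_bytes("iVBORw0KGgo=")                        # fallback: b64decode
```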
```diff
@@ -793,6 +891,33 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
             os.unlink(tmp_file.name)
 
 
+def encode_video(video_path, frame_count_limit=None):
+    # Lazy import because decord is not available on some arm platforms.
+    from decord import VideoReader, cpu
+
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
```
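Inside `encode_video`, `uniform_sample` picks one index from the middle of each of `n` equal-width strata, so the kept frames are spread evenly across the clip instead of clustered at the start. With 10 candidate indices and `n = 4`, the gap is 2.5 and the picks are `int(1.25), int(3.75), int(6.25), int(8.75)`. A standalone check of that arithmetic:

```python
# Standalone check of the uniform_sample helper defined in encode_video.
def uniform_sample(l, n):
    gap = len(l) / n
    idxs = [int(i * gap + gap / 2) for i in range(n)]
    return [l[i] for i in idxs]


assert uniform_sample(list(range(10)), 4) == [1, 3, 6, 8]
```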
```diff
@@ -935,6 +1060,13 @@ def set_ulimit(target_soft_limit=65535):
         logger.warning(f"Fail to set RLIMIT_STACK: {e}")
 
 
+def rank0_log(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(msg)
+
+
 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
@@ -1149,7 +1281,7 @@ def pytorch_profile(name, func, *args, data_size=-1):
 
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-):
+) -> zmq.Socket:
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -1393,6 +1525,32 @@ def get_npu_memory_capacity():
         raise ImportError("torch_npu is required when run on npu device.")
 
 
+def get_cpu_memory_capacity():
+    # Per-rank memory capacity cannot be determined for customized core settings
+    if os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", ""):
+        return None
+    n_numa_node: int = len(get_cpu_ids_by_node())
+    if n_numa_node == 0:
+        # Cannot determine NUMA config, fallback to total memory and avoid ZeroDivisionError.
+        return float(psutil.virtual_memory().total // (1 << 20))
+    try:
+        numa_mem_list = list()
+        file_prefix = "/sys/devices/system/node/"
+        for numa_id in range(n_numa_node):
+            file_meminfo = f"node{numa_id}/meminfo"
+            with open(os.path.join(file_prefix, file_meminfo), "r") as f:
+                # 1st line contains 'MemTotal'
+                line = f.read().split("\n")[0]
+            numa_mem_list.append(int(line.split()[3]))
+        # Retrieved value in KB, need MB
+        numa_mem = float(min(numa_mem_list) // 1024)
+        return numa_mem
+    except FileNotFoundError:
+        numa_mem = psutil.virtual_memory().total / n_numa_node
+        # Retrieved value in Byte, need MB
+        return float(numa_mem // (1 << 20))
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
```
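`get_cpu_memory_capacity` reads the first line of `/sys/devices/system/node/node<N>/meminfo`, which has the form `Node 0 MemTotal: <kB> kB`, so `line.split()[3]` is the kB figure; it then returns the minimum across NUMA nodes in MB so no rank sizes itself against a larger remote node. A small illustration of the parse (the sample line is made up but follows the sysfs format):

```python
# Illustration of the meminfo parse in get_cpu_memory_capacity.
line = "Node 0 MemTotal:       263518860 kB"  # first line of nodeN/meminfo
mem_kb = int(line.split()[3])                 # tokens: Node, 0, MemTotal:, <kB>, kB
mem_mb = float(mem_kb // 1024)                # retrieved in kB, reported in MB
assert mem_mb == 257342.0
```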
```diff
@@ -1402,6 +1560,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_hpu_memory_capacity()
     elif device == "npu":
         gpu_mem = get_npu_memory_capacity()
+    elif device == "cpu":
+        gpu_mem = get_cpu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
@@ -1421,6 +1581,7 @@ def init_custom_process_group(
     store=None,
     group_name=None,
     pg_options=None,
+    device_id=None,
 ):
     from torch.distributed.distributed_c10d import (
         Backend,
@@ -1474,6 +1635,7 @@ def init_custom_process_group(
         group_name=group_name,
         **{pg_options_param_name: pg_options},
         timeout=timeout,
+        device_id=device_id,
     )
 
     _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
```
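Threading `device_id` through `init_custom_process_group` matches the newer `torch.distributed` surface: recent releases (roughly 2.3 and later) accept a `device_id` on group creation, which pins the process to one accelerator and allows eager NCCL communicator initialization. A hedged sketch of a call site, assuming such a torch version and a single-process group; all values are illustrative:

```python
# Sketch of passing device_id at group creation (assumes a torch version
# where init_process_group accepts device_id; values are illustrative).
import torch
import torch.distributed as dist

dist.init_process_group(
    backend="nccl",
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
    device_id=torch.device("cuda:0"),  # enables eager NCCL communicator init
)
```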
```diff
@@ -1938,50 +2100,6 @@ def set_uvicorn_logging_configs():
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
 
 
-def get_ip() -> str:
-    # SGLANG_HOST_IP env can be ignore
-    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-    if host_ip:
-        return host_ip
-
-    # IP is not set, try to get it from the network interface
-
-    # try ipv4
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    try:
-        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try ipv6
-    try:
-        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        # Google's public DNS server, see
-        # https://developers.google.com/speed/public-dns/docs/using#addresses
-        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try using hostname
-    hostname = socket.gethostname()
-    try:
-        ip_addr = socket.gethostbyname(hostname)
-        warnings.warn("using local ip address: {}".format(ip_addr))
-        return ip_addr
-    except Exception:
-        pass
-
-    warnings.warn(
-        "Failed to get the IP address, using 0.0.0.0 by default."
-        "The value can be set by the environment variable"
-        " SGLANG_HOST_IP or HOST_IP.",
-        stacklevel=2,
-    )
-    return "0.0.0.0"
-
-
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:
@@ -2238,16 +2356,9 @@ def bind_or_assign(target, source):
     return source
 
 
-def get_local_ip_auto() -> str:
-    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
-    return (
-        get_local_ip_by_nic(interface)
-        if interface is not None
-        else get_local_ip_by_remote()
-    )
-
-
-def get_local_ip_by_nic(interface: str) -> str:
+def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
+    if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
+        return None
     try:
         import netifaces
     except ImportError as e:
@@ -2268,15 +2379,13 @@ def get_local_ip_by_nic(interface: str) -> str:
         if ip and not ip.startswith("fe80::") and ip != "::1":
             return ip.split("%")[0]
     except (ValueError, OSError) as e:
-        raise ValueError(
-            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        logger.warning(
+            f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
         )
-
-    # Fallback
-    return get_local_ip_by_remote()
+    return None
 
 
-def get_local_ip_by_remote() -> str:
+def get_local_ip_by_remote() -> Optional[str]:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
@@ -2301,7 +2410,51 @@ def get_local_ip_by_remote() -> str:
         s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
     except Exception:
-        raise ValueError("Can not get local ip")
+        logger.warning("Can not get local ip by remote")
+        return None
+
+
+def get_local_ip_auto(fallback: str = None) -> str:
+    """
+    Automatically detect the local IP address using multiple fallback strategies.
+
+    This function attempts to obtain the local IP address through several methods.
+    If all methods fail, it returns the specified fallback value or raises an exception.
+
+    Args:
+        fallback (str, optional): Fallback IP address to return if all detection
+            methods fail. For server applications, explicitly set this to
+            "0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
+            Defaults to None.
+
+    Returns:
+        str: The detected local IP address, or the fallback value if detection fails.
+
+    Raises:
+        ValueError: If IP detection fails and no fallback value is provided.
+
+    Note:
+        The function tries detection methods in the following order:
+        1. Direct IP detection via get_ip()
+        2. Network interface enumeration via get_local_ip_by_nic()
+        3. Remote connection method via get_local_ip_by_remote()
+    """
+    # Try environment variable
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+    logger.debug("get_ip failed")
+    # Fallback
+    if ip := get_local_ip_by_nic():
+        return ip
+    logger.debug("get_local_ip_by_nic failed")
+    # Fallback
+    if ip := get_local_ip_by_remote():
+        return ip
+    logger.debug("get_local_ip_by_remote failed")
+    if fallback:
+        return fallback
+    raise ValueError("Can not get local ip")
 
 
 def is_page_size_one(server_args):
```
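Taken together, these hunks replace raise-on-failure IP helpers with `Optional[str]` helpers and a single `get_local_ip_auto()` entry point: environment variables first, then NIC enumeration (driven by `SGLANG_LOCAL_IP_NIC`), then the UDP-connect trick, and finally the caller's `fallback`. Because the helpers now return `None` instead of raising, the chain composes with plain `if ip := ...` checks. Typical call sites, sketched with the post-rename import path:

```python
# Usage sketch for the reworked IP detection; the import path follows the
# rename to sglang/srt/utils/common.py.
from sglang.srt.utils.common import get_local_ip_auto

server_ip = get_local_ip_auto(fallback="0.0.0.0")  # never raises; binds all ifaces
strict_ip = get_local_ip_auto()                    # raises ValueError on failure
```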
```diff
@@ -2353,7 +2506,7 @@ class BumpAllocator:
 def log_info_on_rank0(logger, msg):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
-    if get_tensor_model_parallel_rank() == 0:
+    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
 
 
@@ -2483,14 +2636,6 @@ def read_system_prompt_from_file(model_name: str) -> str:
     return ""
 
 
-def bind_or_assign(target, source):
-    if target is not None:
-        target.copy_(source)
-        return target
-    else:
-        return source
-
-
 def prepack_weight_if_needed(weight):
     if weight.device != torch.device("cpu"):
         return weight
@@ -3027,3 +3172,232 @@ def check_cuda_result(raw_output):
         raise Exception(f"CUDA error: {err}")
 
     return results
+
+
+def get_physical_device_id(pytorch_device_id: int) -> int:
+    """
+    Convert PyTorch logical device ID to physical device ID.
+    """
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    assert (
+        cuda_visible_devices is not None
+    ), "CUDA_VISIBLE_DEVICES should be set in a scheduler"
+    device_list = cuda_visible_devices.split(",")
+    assert (
+        len(device_list) == 1
+    ), "CUDA_VISIBLE_DEVICES should be set to a single device in a scheduler"
+    return int(device_list[0])
+
+
+def get_device_sm_nvidia_smi():
+    try:
+        # Run nvidia-smi command and capture output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+
+        # Get the first line of output (assuming at least one GPU exists)
+        compute_cap_str = result.stdout.strip().split("\n")[0]
+
+        # Convert string (e.g., "9.0") to tuple of integers (9, 0)
+        major, minor = map(int, compute_cap_str.split("."))
+        return (major, minor)
+
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
+        # Handle cases where nvidia-smi isn't available or output is unexpected
+        print(f"Error getting compute capability: {e}")
+        return (0, 0)  # Default/fallback value
+
+
+def numa_bind_to_node(node: int):
+    libnuma = ctypes.CDLL("libnuma.so")
+    if libnuma.numa_available() < 0:
+        raise SystemError("numa not available on this system")
+
+    libnuma.numa_run_on_node(ctypes.c_int(node))
+    libnuma.numa_set_localalloc()
+
+
+def json_list_type(value):
+    try:
+        return json.loads(value)
+    except json.JSONDecodeError:
+        raise argparse.ArgumentTypeError(
+            f"Invalid JSON list: {value}. Please provide a valid JSON list."
+        )
+
+
+@contextmanager
+def temp_set_cuda_visible_devices(gpu_id: int):
+    original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if original_cuda_visible_devices:
+        cuda_visible_devices = original_cuda_visible_devices.split(",")
+    else:
+        cuda_visible_devices = []
+
+    str_gpu_id = cuda_visible_devices[gpu_id] if cuda_visible_devices else str(gpu_id)
+    os.environ["CUDA_VISIBLE_DEVICES"] = str_gpu_id
+    yield
+    if original_cuda_visible_devices:
+        os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+    else:
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+
+
+def get_extend_input_len_swa_limit(
+    sliding_window_size: int, chunked_prefill_size: int, page_size: int
+) -> int:
+    # 1. a factor of 2x is because each prefill contains chunked_prefill_size tokens,
+    # and between prefills, we run swa_radix_cache.cache_unfinished_req(),
+    # so we unlock the previously locked nodes.
+    # 2. max is to handle the case that chunked_prefill_size is larger than sliding_window_size.
+    # in that case, each prefill contains chunked_prefill_size tokens,
+    # and we can only free out-of-sliding-window kv indices after each prefill.
+    # 3. page_size is because we want to have 1 token extra for generated tokens.
+    return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
+
+
+def get_num_new_pages(
+    seq_lens: torch.Tensor,
+    page_size: int,
+    prefix_lens: Optional[torch.Tensor] = None,
+    decode: bool = False,
+) -> torch.Tensor:
+    """
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use cpu tensors to avoid blocking kernel launch.
+    """
+    cpu_device = torch.device("cpu")
+    assert seq_lens.device == cpu_device
+
+    if prefix_lens is None or decode:
+        # NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
+        assert decode
+        return (seq_lens % page_size == 1).int().sum().item()
+
+    assert prefix_lens.device == cpu_device
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (prefix_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+    sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
+    return sum_num_new_pages.item()
+
+
+class CachedKernel:
+    """
+    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
+
+    This wrapper caches compiled Triton kernels based on keys extracted by a
+    user-provided key function to avoid redundant compilations.
+    """
+
+    def __init__(self, fn, key_fn=None):
+        self.fn = fn
+        assert isinstance(fn, triton.runtime.jit.JITFunction)
+
+        original_fn = fn.fn
+        self.signature = inspect.signature(original_fn)
+        self.param_names = tuple(self.signature.parameters.keys())
+        self.num_args = len(self.param_names)
+
+        # Check that no parameters have default values
+        for name, param in self.signature.parameters.items():
+            assert (
+                param.default is inspect.Parameter.empty
+            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
+
+        functools.update_wrapper(self, original_fn)
+        self.kernel_cache = {}
+
+        # Store the key function
+        self.key_fn = key_fn
+
+    def __getitem__(self, grid):
+        """
+        Index with grid to get a launcher function.
+        Returns a launcher that will handle caching based on the key function.
+        """
+        assert (
+            isinstance(grid, tuple) and len(grid) <= 3
+        ), "Grid must be a tuple with at most 3 dimensions."
+
+        # Normalize grid once
+        if len(grid) < 3:
+            grid = grid + (1,) * (3 - len(grid))
+
+        def launcher(*args, **kwargs):
+            cache_key = self.key_fn(args, kwargs)
+
+            cached_kernel = self.kernel_cache.get(cache_key)
+
+            if cached_kernel is None:
+                # First time: compile and cache the kernel
+                cached_kernel = self.fn[grid](*args, **kwargs)
+                self.kernel_cache[cache_key] = cached_kernel
+                return cached_kernel
+            else:
+                # Use cached kernel
+                all_args = self._build_args(args, kwargs)
+                cached_kernel[grid](*all_args)
+                return cached_kernel
+
+        return launcher
+
+    def _build_args(self, args, kwargs):
+        """
+        Build the complete argument list for kernel invocation.
+        """
+        complete_args = list(args)
+
+        for i in range(len(args), self.num_args):
+            name = self.param_names[i]
+            value = kwargs.get(name, inspect.Parameter.empty)
+            if value is not inspect.Parameter.empty:
+                complete_args.append(value)
+            else:
+                raise ValueError(f"Missing argument: {name}")
+
+        return complete_args
+
+    def _clear_cache(self):
+        """
+        Clear the kernel cache for testing purposes.
+        """
+        self.kernel_cache.clear()
+
+
+def cached_triton_kernel(key_fn=None):
+    """
+    Decorator that enables key-based caching for Triton kernels using a key function.
+
+    It essentially bypasses Triton's built-in caching mechanism, allowing users to
+    define their own caching strategy based on kernel parameters. This helps reduce
+    the heavy overheads of Triton kernel launch when the kernel specialization dispatch
+    is simple.
+
+    Usage:
+        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
+        @triton.jit
+        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+            ...
+
+        # Invoke normally
+        my_kernel[grid](x, y, BLOCK_SIZE=1024)
+
+    Args:
+        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
+                The key can be a single value or a tuple of values.
+
+    Returns:
+        A decorator that wraps the kernel with caching functionality.
+
+    Note: Kernels with default parameter values are not supported and will raise an assertion error.
+    """
+
+    def decorator(fn):
+        return CachedKernel(fn, key_fn)
+
+    return decorator
```