sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
--- sglang/srt/utils.py
+++ sglang/srt/utils/common.py
@@ -15,12 +15,14 @@
 
 from __future__ import annotations
 
+import argparse
 import asyncio
 import builtins
 import ctypes
 import dataclasses
 import functools
 import importlib
+import inspect
 import io
 import ipaddress
 import itertools
@@ -81,11 +83,9 @@ from packaging import version as pkg_version
 from PIL import Image
 from starlette.routing import Mount
 from torch import nn
-from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import FileCacheManager
 from typing_extensions import Literal
 
 from sglang.srt.metrics.func_timer import enable_func_timer
@@ -166,6 +166,7 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)
 
 
+@lru_cache(maxsize=1)
 def is_blackwell():
     if not is_cuda():
         return False
@@ -174,6 +175,8 @@ def is_blackwell():
 
 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
@@ -181,6 +184,8 @@ def is_sm100_supported(device=None) -> bool:
 
 @lru_cache(maxsize=1)
 def is_sm90_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 9) and (
         torch.version.cuda >= "12.3"
     )
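Both probes now short-circuit on non-CUDA-alike builds before touching `torch.cuda`, and `@lru_cache(maxsize=1)` means the driver is queried at most once per process per argument. A minimal sketch of that memoization pattern (hypothetical `expensive_probe`, not sglang code):

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def expensive_probe() -> bool:
    # Stand-in for a driver query such as torch.cuda.get_device_capability().
    print("probing device...")  # executes only on the first call
    return True

expensive_probe()  # prints "probing device..." and caches the result
expensive_probe()  # served from the cache; no side effect
```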
@@ -190,6 +195,7 @@ _warned_bool_env_var_keys = set()
 
 
 def get_bool_env_var(name: str, default: str = "false") -> bool:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name, default)
     value = value.lower()
 
@@ -207,6 +213,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
 
 
 def get_int_env_var(name: str, default: int = 0) -> int:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name)
     if value is None or not value.strip():
         return default
@@ -230,8 +237,16 @@ except:
     is_intel_amx_backend_available = False
 
 
+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+
+
 def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+    return is_amx_tile_supported and is_intel_amx_backend_available
 
 
 def use_intel_amx_backend(layer):
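The point of the new module-level `try/except` shows in the rewritten `cpu_has_amx_support()`: the private probe runs once at import, so the function body becomes a read of two plain booleans that `torch.compile` can treat as constants. A standalone sketch of the same hoisting pattern (it mirrors the hunk above; `torch._C._cpu._is_amx_tile_supported` is a private API that may be missing on some builds, hence the broad except):

```python
import torch

try:
    # Probe once at import time instead of on every call.
    _AMX_TILE_SUPPORTED = torch._C._cpu._is_amx_tile_supported()
except Exception:
    _AMX_TILE_SUPPORTED = False

def cpu_has_amx_support_sketch() -> bool:
    # A constant-foldable read; safe inside torch.compile-d code paths.
    return _AMX_TILE_SUPPORTED
```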
@@ -426,7 +441,9 @@ def get_available_gpu_memory(
 
     elif device == "cpu":
         # TODO: rename the variables in the current function to be not GPU specific
-        free_gpu_memory = psutil.virtual_memory().available
+        total_free_memory = psutil.virtual_memory().available
+        n_numa_node: int = len(get_cpu_ids_by_node())
+        free_gpu_memory = round(total_free_memory / n_numa_node, 3)
     elif device == "npu":
         num_gpus = torch.npu.device_count()
         assert gpu_id < num_gpus
@@ -454,7 +471,7 @@ def is_pin_memory_available() -> bool:
 
 class LayerFn(Protocol):
 
-    def __call__(self,
+    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...
 
 
 def make_layers(
@@ -465,7 +482,7 @@ def make_layers(
     prefix: str = "",
     return_tuple: bool = False,
     offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[
+) -> Tuple[torch.nn.Module, int, int]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices
@@ -501,6 +518,50 @@ def make_layers(
     return modules, start_layer, end_layer
 
 
+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation(CMO).
+    Launch a new stream to prefetch the weight of matmul when running other
+    AIV or communication kernels, aiming to overlap the memory access time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1GB, a large value to prefetch entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
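`get_cmo_stream()` / `wait_cmo_stream()` implement a fork/join pattern: prefetches are issued on a side stream while the main stream keeps computing, and the main stream waits only at the point where the weights are consumed. A device-agnostic sketch of that pattern using plain CUDA streams (an analogy only; the real helpers use `torch_npu.npu_prefetch` on Ascend NPUs):

```python
import torch

def matmul_with_side_stream(weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    side = torch.cuda.Stream()
    # Fork: the side stream must observe all work already queued on the main stream.
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        # Stand-in for the weight prefetch; overlaps with other queued kernels.
        w = weight.contiguous()
    # Join: block the main stream only when the prefetched weight is consumed.
    torch.cuda.current_stream().wait_stream(side)
    return x @ w
```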
@@ -738,6 +799,25 @@ def load_image(
     return image, image_size
 
 
+def get_image_bytes(image_file: Union[str, bytes]):
+    if isinstance(image_file, bytes):
+        return image_file
+    elif image_file.startswith("http://") or image_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+        response = requests.get(image_file, timeout=timeout)
+        return response.content
+    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        with open(image_file, "rb") as f:
+            return f.read()
+    elif image_file.startswith("data:"):
+        image_file = image_file.split(",")[1]
+        return pybase64.b64decode(image_file)
+    elif isinstance(image_file, str):
+        return pybase64.b64decode(image_file)
+    else:
+        raise NotImplementedError(f"Invalid image: {image_file}")
+
+
 def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
     from decord import VideoReader, cpu, gpu
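`get_image_bytes` normalizes five input forms (raw bytes, http(s) URL, file path by extension, data URI, bare base64 string) into bytes. A quick check of the offline branches (assuming the helper stays importable from `sglang.srt.utils` after the move to `utils/common.py`):

```python
import pybase64
from sglang.srt.utils import get_image_bytes  # import path assumed

payload = b"hello"
b64 = pybase64.b64encode(payload).decode()

assert get_image_bytes(payload) == payload                         # bytes pass through
assert get_image_bytes("data:image/png;base64," + b64) == payload  # data URI branch
assert get_image_bytes(b64) == payload                             # bare-base64 fallback
# URL and file-path inputs hit the network / disk, so they are not exercised here.
```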
@@ -793,6 +873,33 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
             os.unlink(tmp_file.name)
 
 
+def encode_video(video_path, frame_count_limit=None):
+    # Lazy import because decord is not available on some arm platforms.
+    from decord import VideoReader, cpu
+
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
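When a video yields more roughly-one-per-second frames than `frame_count_limit`, `uniform_sample` keeps evenly spaced picks, centered within each stride. The index math on a small list:

```python
def uniform_sample(l, n):
    gap = len(l) / n
    # One pick per stride of width `gap`, taken from the middle of the stride.
    return [l[int(i * gap + gap / 2)] for i in range(n)]

print(uniform_sample(list(range(10)), 4))  # -> [1, 3, 6, 8]
```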
@@ -935,6 +1042,13 @@ def set_ulimit(target_soft_limit=65535):
         logger.warning(f"Fail to set RLIMIT_STACK: {e}")
 
 
+def rank0_log(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(msg)
+
+
 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
@@ -1149,7 +1263,7 @@ def pytorch_profile(name, func, *args, data_size=-1):
 
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-):
+) -> zmq.Socket:
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -1393,6 +1507,32 @@ def get_npu_memory_capacity():
         raise ImportError("torch_npu is required when run on npu device.")
 
 
+def get_cpu_memory_capacity():
+    # Per-rank memory capacity cannot be determined for customized core settings
+    if os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", ""):
+        return None
+    n_numa_node: int = len(get_cpu_ids_by_node())
+    if n_numa_node == 0:
+        # Cannot determine NUMA config, fallback to total memory and avoid ZeroDivisionError.
+        return float(psutil.virtual_memory().total // (1 << 20))
+    try:
+        numa_mem_list = list()
+        file_prefix = "/sys/devices/system/node/"
+        for numa_id in range(n_numa_node):
+            file_meminfo = f"node{numa_id}/meminfo"
+            with open(os.path.join(file_prefix, file_meminfo), "r") as f:
+                # 1st line contains 'MemTotal'
+                line = f.read().split("\n")[0]
+            numa_mem_list.append(int(line.split()[3]))
+        # Retrieved value in KB, need MB
+        numa_mem = float(min(numa_mem_list) // 1024)
+        return numa_mem
+    except FileNotFoundError:
+        numa_mem = psutil.virtual_memory().total / n_numa_node
+        # Retrieved value in Byte, need MB
+        return float(numa_mem // (1 << 20))
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
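`get_cpu_memory_capacity` takes the smallest `MemTotal` across NUMA nodes so every rank gets the same budget; that value sits in field index 3 of the first `meminfo` line. A worked parse on an illustrative line (the exact numbers are made up):

```python
# Illustrative first line of /sys/devices/system/node/node0/meminfo:
line = "Node 0 MemTotal:       263775672 kB"
fields = line.split()   # ['Node', '0', 'MemTotal:', '263775672', 'kB']
kb = int(fields[3])
mb = kb // 1024         # the helper reports MB, keeping the per-node minimum
print(mb)               # -> 257593
```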
@@ -1402,6 +1542,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_hpu_memory_capacity()
     elif device == "npu":
         gpu_mem = get_npu_memory_capacity()
+    elif device == "cpu":
+        gpu_mem = get_cpu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
@@ -1421,6 +1563,7 @@ def init_custom_process_group(
     store=None,
     group_name=None,
     pg_options=None,
+    device_id=None,
 ):
     from torch.distributed.distributed_c10d import (
         Backend,
@@ -1474,6 +1617,7 @@ def init_custom_process_group(
         group_name=group_name,
         **{pg_options_param_name: pg_options},
         timeout=timeout,
+        device_id=device_id,
     )
 
     _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
@@ -1938,50 +2082,6 @@ def set_uvicorn_logging_configs():
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
 
 
-def get_ip() -> str:
-    # SGLANG_HOST_IP env can be ignore
-    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-    if host_ip:
-        return host_ip
-
-    # IP is not set, try to get it from the network interface
-
-    # try ipv4
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    try:
-        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try ipv6
-    try:
-        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        # Google's public DNS server, see
-        # https://developers.google.com/speed/public-dns/docs/using#addresses
-        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try using hostname
-    hostname = socket.gethostname()
-    try:
-        ip_addr = socket.gethostbyname(hostname)
-        warnings.warn("using local ip address: {}".format(ip_addr))
-        return ip_addr
-    except Exception:
-        pass
-
-    warnings.warn(
-        "Failed to get the IP address, using 0.0.0.0 by default."
-        "The value can be set by the environment variable"
-        " SGLANG_HOST_IP or HOST_IP.",
-        stacklevel=2,
-    )
-    return "0.0.0.0"
-
-
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:
@@ -2238,16 +2338,9 @@ def bind_or_assign(target, source):
     return source
 
 
-def get_local_ip_auto() -> str:
-    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
-    return (
-        get_local_ip_by_nic(interface)
-        if interface is not None
-        else get_local_ip_by_remote()
-    )
-
-
-def get_local_ip_by_nic(interface: str) -> str:
+def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
+    if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
+        return None
     try:
         import netifaces
     except ImportError as e:
@@ -2268,15 +2361,13 @@ def get_local_ip_by_nic(interface: str) -> str:
             if ip and not ip.startswith("fe80::") and ip != "::1":
                 return ip.split("%")[0]
     except (ValueError, OSError) as e:
-        raise ValueError(
-            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        logger.warning(
+            f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
         )
-
-    # Fallback
-    return get_local_ip_by_remote()
+    return None
 
 
-def get_local_ip_by_remote() -> str:
+def get_local_ip_by_remote() -> Optional[str]:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
@@ -2301,7 +2392,51 @@ def get_local_ip_by_remote() -> str:
         s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
     except Exception:
-        raise ValueError("Can not get local ip")
+        logger.warning("Can not get local ip by remote")
+        return None
+
+
+def get_local_ip_auto(fallback: str = None) -> str:
+    """
+    Automatically detect the local IP address using multiple fallback strategies.
+
+    This function attempts to obtain the local IP address through several methods.
+    If all methods fail, it returns the specified fallback value or raises an exception.
+
+    Args:
+        fallback (str, optional): Fallback IP address to return if all detection
+            methods fail. For server applications, explicitly set this to
+            "0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
+            Defaults to None.
+
+    Returns:
+        str: The detected local IP address, or the fallback value if detection fails.
+
+    Raises:
+        ValueError: If IP detection fails and no fallback value is provided.
+
+    Note:
+        The function tries detection methods in the following order:
+        1. Direct IP detection via get_ip()
+        2. Network interface enumeration via get_local_ip_by_nic()
+        3. Remote connection method via get_local_ip_by_remote()
+    """
+    # Try environment variable
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+    logger.debug("get_ip failed")
+    # Fallback
+    if ip := get_local_ip_by_nic():
+        return ip
+    logger.debug("get_local_ip_by_nic failed")
+    # Fallback
+    if ip := get_local_ip_by_remote():
+        return ip
+    logger.debug("get_local_ip_by_remote failed")
+    if fallback:
+        return fallback
+    raise ValueError("Can not get local ip")
 
 
 def is_page_size_one(server_args):
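The removed `get_ip()` logic survives as the first step of `get_local_ip_auto`, which now walks env var → NIC → remote probe → caller-supplied fallback instead of silently returning "0.0.0.0". Typical server-side usage (import path assumed, as above):

```python
from sglang.srt.utils import get_local_ip_auto  # import path assumed

# Never raises: falls back to binding all interfaces if detection fails.
host = get_local_ip_auto(fallback="0.0.0.0")

# Strict variant: raises ValueError when no address can be determined.
# host = get_local_ip_auto()
```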
@@ -2353,7 +2488,7 @@ class BumpAllocator:
 def log_info_on_rank0(logger, msg):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
-    if get_tensor_model_parallel_rank() == 0:
+    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
 
 
@@ -2483,14 +2618,6 @@ def read_system_prompt_from_file(model_name: str) -> str:
     return ""
 
 
-def bind_or_assign(target, source):
-    if target is not None:
-        target.copy_(source)
-        return target
-    else:
-        return source
-
-
 def prepack_weight_if_needed(weight):
     if weight.device != torch.device("cpu"):
         return weight
@@ -3027,3 +3154,232 @@ def check_cuda_result(raw_output):
         raise Exception(f"CUDA error: {err}")
 
     return results
+
+
+def get_physical_device_id(pytorch_device_id: int) -> int:
+    """
+    Convert PyTorch logical device ID to physical device ID.
+    """
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    assert (
+        cuda_visible_devices is not None
+    ), "CUDA_VISIBLE_DEVICES should be set in a scheduler"
+    device_list = cuda_visible_devices.split(",")
+    assert (
+        len(device_list) == 1
+    ), "CUDA_VISIBLE_DEVICES should be set to a single device in a scheduler"
+    return int(device_list[0])
+
+
+def get_device_sm_nvidia_smi():
+    try:
+        # Run nvidia-smi command and capture output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+
+        # Get the first line of output (assuming at least one GPU exists)
+        compute_cap_str = result.stdout.strip().split("\n")[0]
+
+        # Convert string (e.g., "9.0") to tuple of integers (9, 0)
+        major, minor = map(int, compute_cap_str.split("."))
+        return (major, minor)
+
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
+        # Handle cases where nvidia-smi isn't available or output is unexpected
+        print(f"Error getting compute capability: {e}")
+        return (0, 0)  # Default/fallback value
+
+
+def numa_bind_to_node(node: int):
+    libnuma = ctypes.CDLL("libnuma.so")
+    if libnuma.numa_available() < 0:
+        raise SystemError("numa not available on this system")
+
+    libnuma.numa_run_on_node(ctypes.c_int(node))
+    libnuma.numa_set_localalloc()
+
+
+def json_list_type(value):
+    try:
+        return json.loads(value)
+    except json.JSONDecodeError:
+        raise argparse.ArgumentTypeError(
+            f"Invalid JSON list: {value}. Please provide a valid JSON list."
+        )
+
+
+@contextmanager
+def temp_set_cuda_visible_devices(gpu_id: int):
+    original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if original_cuda_visible_devices:
+        cuda_visible_devices = original_cuda_visible_devices.split(",")
+    else:
+        cuda_visible_devices = []
+
+    str_gpu_id = cuda_visible_devices[gpu_id] if cuda_visible_devices else str(gpu_id)
+    os.environ["CUDA_VISIBLE_DEVICES"] = str_gpu_id
+    yield
+    if original_cuda_visible_devices:
+        os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+    else:
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+
+
+def get_extend_input_len_swa_limit(
+    sliding_window_size: int, chunked_prefill_size: int, page_size: int
+) -> int:
+    # 1. a factor of 2x is because each prefill contains chunked_prefill_size tokens,
+    #    and between prefills, we run swa_radix_cache.cache_unfinished_req(),
+    #    so we unlock the previously locked nodes.
+    # 2. max is to handle the case that chunked_prefill_size is larger than sliding_window_size.
+    #    in that case, each prefill contains chunked_prefill_size tokens,
+    #    and we can only free out-of-sliding-window kv indices after each prefill.
+    # 3. page_size is because we want to have 1 token extra for generated tokens.
+    return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
+
+
+def get_num_new_pages(
+    seq_lens: torch.Tensor,
+    page_size: int,
+    prefix_lens: Optional[torch.Tensor] = None,
+    decode: bool = False,
+) -> torch.Tensor:
+    """
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use cpu tensors to avoid blocking kernel launch.
+    """
+    cpu_device = torch.device("cpu")
+    assert seq_lens.device == cpu_device
+
+    if prefix_lens is None or decode:
+        # NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
+        assert decode
+        return (seq_lens % page_size == 1).int().sum().item()
+
+    assert prefix_lens.device == cpu_device
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (prefix_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+    sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
+    return sum_num_new_pages.item()
+
+
+class CachedKernel:
+    """
+    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
+
+    This wrapper caches compiled Triton kernels based on keys extracted by a
+    user-provided key function to avoid redundant compilations.
+    """
+
+    def __init__(self, fn, key_fn=None):
+        self.fn = fn
+        assert isinstance(fn, triton.runtime.jit.JITFunction)
+
+        original_fn = fn.fn
+        self.signature = inspect.signature(original_fn)
+        self.param_names = tuple(self.signature.parameters.keys())
+        self.num_args = len(self.param_names)
+
+        # Check that no parameters have default values
+        for name, param in self.signature.parameters.items():
+            assert (
+                param.default is inspect.Parameter.empty
+            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
+
+        functools.update_wrapper(self, original_fn)
+        self.kernel_cache = {}
+
+        # Store the key function
+        self.key_fn = key_fn
+
+    def __getitem__(self, grid):
+        """
+        Index with grid to get a launcher function.
+        Returns a launcher that will handle caching based on the key function.
+        """
+        assert (
+            isinstance(grid, tuple) and len(grid) <= 3
+        ), "Grid must be a tuple with at most 3 dimensions."
+
+        # Normalize grid once
+        if len(grid) < 3:
+            grid = grid + (1,) * (3 - len(grid))
+
+        def launcher(*args, **kwargs):
+            cache_key = self.key_fn(args, kwargs)
+
+            cached_kernel = self.kernel_cache.get(cache_key)
+
+            if cached_kernel is None:
+                # First time: compile and cache the kernel
+                cached_kernel = self.fn[grid](*args, **kwargs)
+                self.kernel_cache[cache_key] = cached_kernel
+                return cached_kernel
+            else:
+                # Use cached kernel
+                all_args = self._build_args(args, kwargs)
+                cached_kernel[grid](*all_args)
+                return cached_kernel
+
+        return launcher
+
+    def _build_args(self, args, kwargs):
+        """
+        Build the complete argument list for kernel invocation.
+        """
+        complete_args = list(args)
+
+        for i in range(len(args), self.num_args):
+            name = self.param_names[i]
+            value = kwargs.get(name, inspect.Parameter.empty)
+            if value is not inspect.Parameter.empty:
+                complete_args.append(value)
+            else:
+                raise ValueError(f"Missing argument: {name}")
+
+        return complete_args
+
+    def _clear_cache(self):
+        """
+        Clear the kernel cache for testing purposes.
+        """
+        self.kernel_cache.clear()
+
+
+def cached_triton_kernel(key_fn=None):
+    """
+    Decorator that enables key-based caching for Triton kernels using a key function.
+
+    It essentially bypasses Triton's built-in caching mechanism, allowing users to
+    define their own caching strategy based on kernel parameters. This helps reduce
+    the heavy overheads of Triton kernel launch when the kernel specialization dispatch
+    is simple.
+
+    Usage:
+        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
+        @triton.jit
+        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+            ...
+
+        # Invoke normally
+        my_kernel[grid](x, y, BLOCK_SIZE=1024)
+
+    Args:
+        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
+            The key can be a single value or a tuple of values.
+
+    Returns:
+        A decorator that wraps the kernel with caching functionality.
+
+    Note: Kernels with default parameter values are not supported and will raise an assertion error.
+    """
+
+    def decorator(fn):
+        return CachedKernel(fn, key_fn)
+
+    return decorator