sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_one_batch.py +47 -28
 - sglang/bench_one_batch_server.py +41 -25
 - sglang/bench_serving.py +378 -160
 - sglang/check_env.py +1 -1
 - sglang/compile_deep_gemm.py +6 -2
 - sglang/global_config.py +1 -25
 - sglang/lang/api.py +6 -0
 - sglang/lang/interpreter.py +1 -0
 - sglang/lang/ir.py +13 -0
 - sglang/launch_server.py +10 -15
 - sglang/profiler.py +18 -1
 - sglang/srt/_custom_ops.py +1 -1
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
 - sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
 - sglang/srt/compilation/backend.py +437 -0
 - sglang/srt/compilation/compilation_config.py +20 -0
 - sglang/srt/compilation/compilation_counter.py +47 -0
 - sglang/srt/compilation/compile.py +210 -0
 - sglang/srt/compilation/compiler_interface.py +503 -0
 - sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
 - sglang/srt/compilation/fix_functionalization.py +134 -0
 - sglang/srt/compilation/fx_utils.py +83 -0
 - sglang/srt/compilation/inductor_pass.py +140 -0
 - sglang/srt/compilation/pass_manager.py +66 -0
 - sglang/srt/compilation/piecewise_context_manager.py +40 -0
 - sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
 - sglang/srt/configs/__init__.py +4 -0
 - sglang/srt/configs/deepseek_ocr.py +262 -0
 - sglang/srt/configs/deepseekvl2.py +194 -96
 - sglang/srt/configs/dots_vlm.py +2 -7
 - sglang/srt/configs/falcon_h1.py +13 -64
 - sglang/srt/configs/load_config.py +25 -2
 - sglang/srt/configs/mamba_utils.py +117 -0
 - sglang/srt/configs/model_config.py +136 -25
 - sglang/srt/configs/modelopt_config.py +30 -0
 - sglang/srt/configs/nemotron_h.py +286 -0
 - sglang/srt/configs/olmo3.py +105 -0
 - sglang/srt/configs/points_v15_chat.py +29 -0
 - sglang/srt/configs/qwen3_next.py +11 -47
 - sglang/srt/configs/qwen3_omni.py +613 -0
 - sglang/srt/configs/qwen3_vl.py +0 -10
 - sglang/srt/connector/remote_instance.py +1 -1
 - sglang/srt/constrained/base_grammar_backend.py +5 -1
 - sglang/srt/constrained/llguidance_backend.py +5 -0
 - sglang/srt/constrained/outlines_backend.py +1 -1
 - sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
 - sglang/srt/constrained/utils.py +12 -0
 - sglang/srt/constrained/xgrammar_backend.py +20 -11
 - sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
 - sglang/srt/disaggregation/base/conn.py +17 -4
 - sglang/srt/disaggregation/common/conn.py +4 -2
 - sglang/srt/disaggregation/decode.py +123 -31
 - sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
 - sglang/srt/disaggregation/fake/conn.py +11 -3
 - sglang/srt/disaggregation/mooncake/conn.py +157 -19
 - sglang/srt/disaggregation/nixl/conn.py +69 -24
 - sglang/srt/disaggregation/prefill.py +96 -270
 - sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
 - sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
 - sglang/srt/distributed/device_communicators/pynccl.py +24 -12
 - sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
 - sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
 - sglang/srt/distributed/naive_distributed.py +5 -4
 - sglang/srt/distributed/parallel_state.py +63 -19
 - sglang/srt/elastic_ep/elastic_ep.py +74 -0
 - sglang/srt/entrypoints/context.py +3 -2
 - sglang/srt/entrypoints/engine.py +83 -80
 - sglang/srt/entrypoints/grpc_server.py +430 -234
 - sglang/srt/entrypoints/harmony_utils.py +2 -2
 - sglang/srt/entrypoints/http_server.py +195 -102
 - sglang/srt/entrypoints/http_server_engine.py +1 -7
 - sglang/srt/entrypoints/openai/protocol.py +225 -37
 - sglang/srt/entrypoints/openai/serving_base.py +49 -2
 - sglang/srt/entrypoints/openai/serving_chat.py +29 -74
 - sglang/srt/entrypoints/openai/serving_classify.py +204 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +15 -1
 - sglang/srt/entrypoints/openai/serving_responses.py +5 -2
 - sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
 - sglang/srt/environ.py +58 -6
 - sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
 - sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
 - sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
 - sglang/srt/eplb/expert_distribution.py +33 -4
 - sglang/srt/eplb/expert_location_dispatch.py +2 -2
 - sglang/srt/eplb/expert_location_updater.py +2 -2
 - sglang/srt/function_call/base_format_detector.py +17 -18
 - sglang/srt/function_call/function_call_parser.py +20 -14
 - sglang/srt/function_call/glm4_moe_detector.py +1 -5
 - sglang/srt/function_call/gpt_oss_detector.py +1 -1
 - sglang/srt/function_call/json_array_parser.py +0 -2
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/function_call/utils.py +2 -2
 - sglang/srt/grpc/compile_proto.py +3 -3
 - sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
 - sglang/srt/grpc/health_servicer.py +189 -0
 - sglang/srt/grpc/scheduler_launcher.py +181 -0
 - sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
 - sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
 - sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
 - sglang/srt/layers/activation.py +10 -1
 - sglang/srt/layers/attention/aiter_backend.py +3 -3
 - sglang/srt/layers/attention/ascend_backend.py +17 -1
 - sglang/srt/layers/attention/attention_registry.py +43 -23
 - sglang/srt/layers/attention/base_attn_backend.py +20 -1
 - sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
 - sglang/srt/layers/attention/fla/chunk.py +0 -1
 - sglang/srt/layers/attention/fla/chunk_o.py +1 -1
 - sglang/srt/layers/attention/fla/index.py +0 -2
 - sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
 - sglang/srt/layers/attention/fla/utils.py +0 -3
 - sglang/srt/layers/attention/fla/wy_fast.py +0 -2
 - sglang/srt/layers/attention/flashattention_backend.py +24 -10
 - sglang/srt/layers/attention/flashinfer_backend.py +258 -22
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
 - sglang/srt/layers/attention/flashmla_backend.py +2 -2
 - sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
 - sglang/srt/layers/attention/intel_amx_backend.py +1 -1
 - sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
 - sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
 - sglang/srt/layers/attention/mamba/mamba.py +189 -241
 - sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
 - sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
 - sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
 - sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
 - sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
 - sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
 - sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
 - sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
 - sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
 - sglang/srt/layers/attention/nsa/utils.py +0 -1
 - sglang/srt/layers/attention/nsa_backend.py +404 -90
 - sglang/srt/layers/attention/triton_backend.py +208 -34
 - sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
 - sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
 - sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
 - sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
 - sglang/srt/layers/attention/utils.py +89 -7
 - sglang/srt/layers/attention/vision.py +3 -3
 - sglang/srt/layers/attention/xpu_backend.py +1028 -0
 - sglang/srt/layers/communicator.py +12 -7
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
 - sglang/srt/layers/dp_attention.py +17 -0
 - sglang/srt/layers/layernorm.py +64 -19
 - sglang/srt/layers/linear.py +9 -1
 - sglang/srt/layers/logits_processor.py +152 -17
 - sglang/srt/layers/modelopt_utils.py +11 -0
 - sglang/srt/layers/moe/cutlass_moe.py +0 -2
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
 - sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
 - sglang/srt/layers/moe/ep_moe/layer.py +154 -625
 - sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
 - sglang/srt/layers/moe/moe_runner/runner.py +6 -0
 - sglang/srt/layers/moe/moe_runner/triton.py +3 -1
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
 - sglang/srt/layers/moe/router.py +51 -15
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
 - sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
 - sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
 - sglang/srt/layers/moe/topk.py +7 -6
 - sglang/srt/layers/moe/utils.py +20 -5
 - sglang/srt/layers/quantization/__init__.py +5 -58
 - sglang/srt/layers/quantization/awq.py +183 -9
 - sglang/srt/layers/quantization/awq_triton.py +29 -0
 - sglang/srt/layers/quantization/base_config.py +27 -1
 - sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
 - sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
 - sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
 - sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
 - sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
 - sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
 - sglang/srt/layers/quantization/fp8.py +152 -81
 - sglang/srt/layers/quantization/fp8_kernel.py +55 -10
 - sglang/srt/layers/quantization/fp8_utils.py +42 -14
 - sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/gptq.py +0 -1
 - sglang/srt/layers/quantization/int8_kernel.py +18 -2
 - sglang/srt/layers/quantization/marlin_utils.py +12 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +125 -100
 - sglang/srt/layers/quantization/mxfp4.py +35 -68
 - sglang/srt/layers/quantization/petit.py +1 -1
 - sglang/srt/layers/quantization/quark/quark.py +3 -1
 - sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
 - sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
 - sglang/srt/layers/quantization/unquant.py +23 -48
 - sglang/srt/layers/quantization/utils.py +0 -1
 - sglang/srt/layers/quantization/w4afp8.py +87 -20
 - sglang/srt/layers/quantization/w8a8_int8.py +30 -24
 - sglang/srt/layers/radix_attention.py +62 -9
 - sglang/srt/layers/rotary_embedding.py +686 -17
 - sglang/srt/layers/sampler.py +47 -16
 - sglang/srt/layers/sparse_pooler.py +98 -0
 - sglang/srt/layers/utils.py +0 -1
 - sglang/srt/layers/vocab_parallel_embedding.py +4 -1
 - sglang/srt/lora/backend/triton_backend.py +0 -1
 - sglang/srt/lora/eviction_policy.py +139 -0
 - sglang/srt/lora/lora_manager.py +24 -9
 - sglang/srt/lora/lora_registry.py +1 -1
 - sglang/srt/lora/mem_pool.py +40 -16
 - sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
 - sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
 - sglang/srt/managers/cache_controller.py +48 -17
 - sglang/srt/managers/data_parallel_controller.py +146 -42
 - sglang/srt/managers/detokenizer_manager.py +40 -13
 - sglang/srt/managers/io_struct.py +69 -16
 - sglang/srt/managers/mm_utils.py +20 -18
 - sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
 - sglang/srt/managers/overlap_utils.py +96 -19
 - sglang/srt/managers/schedule_batch.py +241 -511
 - sglang/srt/managers/schedule_policy.py +15 -2
 - sglang/srt/managers/scheduler.py +420 -514
 - sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
 - sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
 - sglang/srt/managers/scheduler_pp_mixin.py +341 -0
 - sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
 - sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
 - sglang/srt/managers/tokenizer_manager.py +375 -95
 - sglang/srt/managers/tp_worker.py +212 -161
 - sglang/srt/managers/utils.py +78 -2
 - sglang/srt/mem_cache/allocator.py +7 -2
 - sglang/srt/mem_cache/allocator_ascend.py +2 -2
 - sglang/srt/mem_cache/base_prefix_cache.py +2 -2
 - sglang/srt/mem_cache/chunk_cache.py +13 -2
 - sglang/srt/mem_cache/common.py +480 -0
 - sglang/srt/mem_cache/evict_policy.py +16 -1
 - sglang/srt/mem_cache/hicache_storage.py +11 -2
 - sglang/srt/mem_cache/hiradix_cache.py +16 -3
 - sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
 - sglang/srt/mem_cache/memory_pool.py +517 -219
 - sglang/srt/mem_cache/memory_pool_host.py +0 -1
 - sglang/srt/mem_cache/multimodal_cache.py +0 -1
 - sglang/srt/mem_cache/radix_cache.py +53 -19
 - sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
 - sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
 - sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
 - sglang/srt/mem_cache/storage/backend_factory.py +2 -2
 - sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
 - sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
 - sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
 - sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
 - sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
 - sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
 - sglang/srt/mem_cache/swa_radix_cache.py +92 -26
 - sglang/srt/metrics/collector.py +31 -0
 - sglang/srt/metrics/func_timer.py +1 -1
 - sglang/srt/model_executor/cuda_graph_runner.py +43 -5
 - sglang/srt/model_executor/forward_batch_info.py +71 -25
 - sglang/srt/model_executor/model_runner.py +362 -270
 - sglang/srt/model_executor/npu_graph_runner.py +2 -3
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
 - sglang/srt/model_loader/__init__.py +1 -1
 - sglang/srt/model_loader/loader.py +424 -27
 - sglang/srt/model_loader/utils.py +0 -1
 - sglang/srt/model_loader/weight_utils.py +47 -28
 - sglang/srt/models/apertus.py +2 -3
 - sglang/srt/models/arcee.py +2 -2
 - sglang/srt/models/bailing_moe.py +13 -52
 - sglang/srt/models/bailing_moe_nextn.py +3 -4
 - sglang/srt/models/bert.py +1 -1
 - sglang/srt/models/deepseek_nextn.py +19 -3
 - sglang/srt/models/deepseek_ocr.py +1516 -0
 - sglang/srt/models/deepseek_v2.py +418 -140
 - sglang/srt/models/dots_ocr.py +0 -2
 - sglang/srt/models/dots_vlm.py +0 -1
 - sglang/srt/models/dots_vlm_vit.py +1 -1
 - sglang/srt/models/falcon_h1.py +13 -19
 - sglang/srt/models/gemma3_mm.py +16 -0
 - sglang/srt/models/gemma3n_mm.py +1 -2
 - sglang/srt/models/glm4_moe.py +327 -382
 - sglang/srt/models/glm4_moe_nextn.py +6 -16
 - sglang/srt/models/glm4v.py +2 -1
 - sglang/srt/models/glm4v_moe.py +32 -199
 - sglang/srt/models/gpt_oss.py +5 -5
 - sglang/srt/models/grok.py +10 -23
 - sglang/srt/models/hunyuan.py +2 -7
 - sglang/srt/models/interns1.py +0 -1
 - sglang/srt/models/kimi_vl.py +1 -7
 - sglang/srt/models/kimi_vl_moonvit.py +3 -1
 - sglang/srt/models/llama.py +2 -2
 - sglang/srt/models/llama_eagle3.py +1 -1
 - sglang/srt/models/longcat_flash.py +5 -22
 - sglang/srt/models/longcat_flash_nextn.py +3 -14
 - sglang/srt/models/mimo.py +2 -13
 - sglang/srt/models/mimo_mtp.py +1 -2
 - sglang/srt/models/minicpmo.py +7 -5
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/mixtral.py +1 -4
 - sglang/srt/models/mllama.py +1 -1
 - sglang/srt/models/mllama4.py +13 -3
 - sglang/srt/models/nemotron_h.py +511 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/olmo2.py +31 -4
 - sglang/srt/models/opt.py +5 -5
 - sglang/srt/models/phi.py +1 -1
 - sglang/srt/models/phi4mm.py +1 -1
 - sglang/srt/models/phimoe.py +0 -1
 - sglang/srt/models/pixtral.py +0 -3
 - sglang/srt/models/points_v15_chat.py +186 -0
 - sglang/srt/models/qwen.py +0 -1
 - sglang/srt/models/qwen2.py +22 -1
 - sglang/srt/models/qwen2_5_vl.py +3 -3
 - sglang/srt/models/qwen2_audio.py +2 -15
 - sglang/srt/models/qwen2_moe.py +15 -12
 - sglang/srt/models/qwen2_vl.py +5 -2
 - sglang/srt/models/qwen3.py +34 -4
 - sglang/srt/models/qwen3_moe.py +19 -37
 - sglang/srt/models/qwen3_next.py +7 -12
 - sglang/srt/models/qwen3_next_mtp.py +3 -4
 - sglang/srt/models/qwen3_omni_moe.py +661 -0
 - sglang/srt/models/qwen3_vl.py +37 -33
 - sglang/srt/models/qwen3_vl_moe.py +57 -185
 - sglang/srt/models/roberta.py +55 -3
 - sglang/srt/models/sarashina2_vision.py +0 -1
 - sglang/srt/models/step3_vl.py +3 -5
 - sglang/srt/models/utils.py +11 -1
 - sglang/srt/multimodal/processors/base_processor.py +7 -2
 - sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
 - sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
 - sglang/srt/multimodal/processors/dots_vlm.py +0 -1
 - sglang/srt/multimodal/processors/glm4v.py +2 -6
 - sglang/srt/multimodal/processors/internvl.py +0 -2
 - sglang/srt/multimodal/processors/janus_pro.py +0 -1
 - sglang/srt/multimodal/processors/mllama4.py +0 -8
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/phi4mm.py +0 -1
 - sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
 - sglang/srt/multimodal/processors/qwen_vl.py +75 -16
 - sglang/srt/multimodal/processors/step3_vl.py +1 -1
 - sglang/srt/parser/conversation.py +41 -0
 - sglang/srt/parser/reasoning_parser.py +28 -2
 - sglang/srt/sampling/custom_logit_processor.py +77 -2
 - sglang/srt/sampling/sampling_batch_info.py +17 -22
 - sglang/srt/sampling/sampling_params.py +70 -2
 - sglang/srt/server_args.py +846 -163
 - sglang/srt/server_args_config_parser.py +1 -1
 - sglang/srt/single_batch_overlap.py +36 -31
 - sglang/srt/speculative/base_spec_worker.py +34 -0
 - sglang/srt/speculative/draft_utils.py +226 -0
 - sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
 - sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
 - sglang/srt/speculative/eagle_info.py +57 -18
 - sglang/srt/speculative/eagle_info_v2.py +458 -0
 - sglang/srt/speculative/eagle_utils.py +138 -0
 - sglang/srt/speculative/eagle_worker.py +83 -280
 - sglang/srt/speculative/eagle_worker_v2.py +702 -0
 - sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
 - sglang/srt/speculative/ngram_worker.py +12 -11
 - sglang/srt/speculative/spec_info.py +2 -0
 - sglang/srt/speculative/spec_utils.py +38 -3
 - sglang/srt/speculative/standalone_worker.py +4 -14
 - sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
 - sglang/srt/two_batch_overlap.py +28 -14
 - sglang/srt/utils/__init__.py +1 -1
 - sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
 - sglang/srt/utils/common.py +272 -82
 - sglang/srt/utils/hf_transformers_utils.py +44 -17
 - sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
 - sglang/srt/{offloader.py → utils/offloader.py} +4 -4
 - sglang/srt/utils/profile_merger.py +199 -0
 - sglang/test/attention/test_flashattn_backend.py +1 -1
 - sglang/test/attention/test_flashattn_mla_backend.py +0 -1
 - sglang/test/attention/test_prefix_chunk_info.py +0 -2
 - sglang/test/attention/test_trtllm_mla_backend.py +221 -53
 - sglang/test/few_shot_gsm8k_engine.py +2 -4
 - sglang/test/kit_matched_stop.py +157 -0
 - sglang/test/longbench_v2/__init__.py +1 -0
 - sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
 - sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
 - sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
 - sglang/test/run_eval.py +41 -0
 - sglang/test/runners.py +2 -0
 - sglang/test/send_one.py +42 -7
 - sglang/test/simple_eval_common.py +3 -0
 - sglang/test/simple_eval_gpqa.py +0 -1
 - sglang/test/simple_eval_humaneval.py +0 -3
 - sglang/test/simple_eval_longbench_v2.py +344 -0
 - sglang/test/test_block_fp8.py +1 -2
 - sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
 - sglang/test/test_cutlass_moe.py +1 -2
 - sglang/test/test_cutlass_w4a8_moe.py +10 -20
 - sglang/test/test_deterministic.py +463 -107
 - sglang/test/test_deterministic_utils.py +74 -0
 - sglang/test/test_disaggregation_utils.py +81 -0
 - sglang/test/test_marlin_moe.py +0 -1
 - sglang/test/test_utils.py +85 -20
 - sglang/version.py +1 -1
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
 - sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
 - sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
 - sglang/srt/models/vila.py +0 -306
 - sglang/srt/speculative/build_eagle_tree.py +0 -427
 - sglang/test/test_block_fp8_ep.py +0 -358
 - /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
 - /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
 - /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
 
    
    sglang/srt/layers/sampler.py
CHANGED

@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

 if is_cuda():
@@ -27,13 +27,13 @@ if is_cuda():
 logger = logging.getLogger(__name__)

 SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
-RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")


 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
+        self.use_nan_detection = get_global_server_args().enable_nan_detection
         self.tp_sync_group = get_tp_group().device_group

         if is_dp_attention_enabled():
@@ -91,20 +91,40 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-
         else:
+            can_sample_directly_from_probs = (
+                not sampling_info.need_top_p_sampling
+                and not sampling_info.need_top_k_sampling
+                and not sampling_info.need_min_p_sampling
+            )
+
             # If requested, cache probabilities from original logits before temperature scaling.
-            if return_logprob and RETURN_ORIGINAL_LOGPROB:
+            if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
                 probs_without_temp_scaling = torch.softmax(logits, dim=-1)

+            if get_global_server_args().rl_on_policy_target == "fsdp":
+                logits_div_temperature = (
+                    logits.bfloat16().div(sampling_info.temperatures).bfloat16()
+                )
+                logprobs_via_logsoftmax_kernel = torch.log_softmax(
+                    logits_div_temperature, dim=-1
+                )
+
             # Post process logits
             logits.div_(sampling_info.temperatures)
             logits[:] = torch.softmax(logits, dim=-1)
             probs = logits
             del logits

-            if global_server_args_dict["sampling_backend"] == "flashinfer":
-
+            if can_sample_directly_from_probs:
+                # when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs
+                batch_next_token_ids = sampling_from_probs_torch(
+                    probs,
+                    sampling_seed=sampling_info.sampling_seed,
+                    positions=positions,
+                )
+            else:
+                if get_global_server_args().sampling_backend == "flashinfer":
                 if sampling_info.need_min_p_sampling:
                     probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                     probs = top_p_renorm_prob(probs, sampling_info.top_ps)
@@ -119,7 +139,7 @@ class Sampler(nn.Module):
                         filter_apply_order="joint",
                         check_nan=self.use_nan_detection,
                     )
-            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                elif get_global_server_args().sampling_backend == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                     probs,
@@ -132,12 +152,15 @@ class Sampler(nn.Module):
                 )
             else:
                 raise ValueError(
-                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+                    f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
                 )

         if return_logprob:
+            if get_global_server_args().rl_on_policy_target == "fsdp":
+                logprobs = logprobs_via_logsoftmax_kernel
+                del logprobs_via_logsoftmax_kernel
             # clamp to avoid -inf
-            if RETURN_ORIGINAL_LOGPROB:
+            elif SGLANG_RETURN_ORIGINAL_LOGPROB:
                 logprobs = torch.log(probs_without_temp_scaling).clamp(
                     min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
@@ -288,21 +311,29 @@ def multinomial_with_seed(
     """
     n, m = inputs.shape
     col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
-    step_seed = seed * 19349663 ^ positions * 73856093
+    step_seed = (seed * 19349663) ^ (positions * 73856093)
     seed_expanded = step_seed.unsqueeze(-1)
-    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
+    hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599)
     uniform_samples = (hashed % (2**24)).float() / (2**24)
-    epsilon = 1e-9
-    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
+    epsilon = 1e-10
+    uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
+    gumbel_noise = -torch.log(-torch.log(uniform_samples))
     log_probs = torch.log(inputs + epsilon)
     perturbed_log_probs = log_probs + gumbel_noise
     return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)


-def sampling_from_probs_torch(probs: torch.Tensor):
+def sampling_from_probs_torch(
+    probs: torch.Tensor,
+    sampling_seed: Optional[torch.Tensor] = None,
+    positions: Optional[torch.Tensor] = None,
+):
     """A sampling implementation with native pytorch operations, without
     top-k, top-p, or min-p filtering."""
-    sampled_index = torch.multinomial(probs, num_samples=1)
+    if sampling_seed is not None:
+        sampled_index = multinomial_with_seed(probs, sampling_seed, positions)
+    else:
+        sampled_index = torch.multinomial(probs, num_samples=1)
     batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
     return batch_next_token_ids
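The seeded path above relies on the Gumbel-max trick: hash the per-request seed with the token position and each vocabulary column into deterministic pseudo-uniform noise, then take the argmax of log-probabilities plus Gumbel noise, which draws from the categorical distribution. The parentheses added in the hash lines are purely for readability, since Python's * operator already binds tighter than ^. A self-contained sketch of the idea, reusing the constants visible in the hunk (illustrative only, not sglang's public API):

import torch

def seeded_sample(probs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Hash (seed, position, vocab column) into deterministic pseudo-uniform noise.
    n, m = probs.shape
    col_indices = torch.arange(m, device=probs.device).unsqueeze(0)
    step_seed = (seed * 19349663) ^ (positions * 73856093)
    hashed = (step_seed.unsqueeze(-1) * 8589934591) ^ (col_indices * 479001599)
    u = ((hashed % (2**24)).float() / (2**24)).clamp(1e-10, 1.0 - 1e-10)
    gumbel_noise = -torch.log(-torch.log(u))
    # Gumbel-max trick: with ideal uniform noise, argmax(log p + g) samples from p.
    return torch.argmax(torch.log(probs + 1e-10) + gumbel_noise, dim=1)

probs = torch.softmax(torch.randn(2, 8), dim=-1)
seed = torch.tensor([42, 42])
positions = torch.tensor([0, 1])
a = seeded_sample(probs, seed, positions)
b = seeded_sample(probs, seed, positions)
assert torch.equal(a, b)  # same (seed, position) always yields the same token

Because the noise is a pure function of (seed, position, column), replaying a request with the same sampling seed reproduces its tokens exactly, which is what the deterministic-sampling tests in this release (sglang/test/test_deterministic.py, sglang/test/test_deterministic_utils.py) exercise.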
    sglang/srt/layers/sparse_pooler.py
ADDED

@@ -0,0 +1,98 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PretrainedConfig
+
+from sglang.srt.model_executor.model_runner import ForwardBatch
+
+
+@dataclass
+class SparseEmbeddingOutput:
+    embeddings: torch.Tensor  # [batch_size, vocab_size]
+
+
+class SparsePooler(nn.Module):
+    """A layer that pools hidden states into sparse vocabulary-space embeddings.
+
+    This layer does the following:
+    1. Applies a linear transformation + ReLU to get token-level weights
+    2. Maps these weights to vocabulary positions using token IDs
+    3. Aggregates weights for repeated tokens using max pooling
+    4. Returns sparse embeddings in vocabulary space
+
+    Attributes:
+        config: Model configuration containing vocab_size and hidden_size
+        sparse_linear: Linear layer for computing token weights
+        vocab_size: Size of vocabulary for output embeddings
+    """
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+
+        # Validate required attributes
+        if not hasattr(config, "vocab_size"):
+            raise AttributeError(
+                f"Config {type(config)} missing required 'vocab_size' attribute"
+            )
+        if not hasattr(config, "hidden_size"):
+            raise AttributeError(
+                f"Config {type(config)} missing required 'hidden_size' attribute"
+            )
+
+        self.vocab_size = config.vocab_size
+        self.sparse_linear = nn.Linear(config.hidden_size, 1)
+        self._weights_loaded = False
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> SparseEmbeddingOutput:
+        """
+        Forward pass for sparse pooling.
+
+        Args:
+            hidden_states: Packed sequence hidden states [total_tokens, hidden_size]
+            forward_batch: Batch information with sequence lengths and input_ids
+
+        Returns:
+            SparseEmbeddingOutput with embeddings of shape [batch_size, vocab_size]
+        """
+        if not self._weights_loaded:
+            raise ValueError(
+                "Sparse pooling weights not loaded. Call load_weights() first"
+            )
+
+        # Apply sparse linear + ReLU to get token weights
+        token_weights = F.relu(self.sparse_linear(hidden_states)).squeeze(
+            -1
+        )  # [total_tokens]
+
+        # Create batch indices for packed sequences
+        batch_indices = torch.repeat_interleave(
+            torch.arange(
+                len(forward_batch.extend_seq_lens), device=hidden_states.device
+            ),
+            forward_batch.extend_seq_lens,
+        )
+
+        # Initialize sparse embedding output
+        sparse_embedding = torch.zeros(
+            len(forward_batch.extend_seq_lens),
+            self.vocab_size,
+            dtype=token_weights.dtype,
+            device=token_weights.device,
+        )
+
+        # Map to vocabulary space using scatter_reduce with amax
+        flat_indices = batch_indices * self.vocab_size + forward_batch.input_ids
+        sparse_embedding.view(-1).scatter_reduce_(
+            0, flat_indices, token_weights, reduce="amax"
+        )
+
+        return SparseEmbeddingOutput(embeddings=sparse_embedding)
+
+    def load_weights(self, state_dict: dict):
+        """Load weights from state dict (called by the model)."""
+        self.sparse_linear.load_state_dict(state_dict)
+        self._weights_loaded = True
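SparsePooler flattens each (batch row, vocab id) pair into a single index so that one scatter_reduce_ call can max-pool every token weight into its row's vocabulary slot; a token repeated within a sequence keeps its largest activation rather than a sum. A toy, standalone sketch of that indexing (tensors fabricated for illustration; names mirror the new file):

import torch

vocab_size = 10
extend_seq_lens = torch.tensor([3, 2])                   # two packed sequences
input_ids = torch.tensor([4, 7, 4, 1, 7])                # 3 + 2 tokens; token 4 repeats in seq 0
token_weights = torch.tensor([0.2, 0.5, 0.9, 0.3, 0.8])  # ReLU(sparse_linear(h)) per token

# Row index of the sequence each packed token belongs to: [0, 0, 0, 1, 1]
batch_indices = torch.repeat_interleave(torch.arange(len(extend_seq_lens)), extend_seq_lens)

sparse = torch.zeros(len(extend_seq_lens), vocab_size)
flat_indices = batch_indices * vocab_size + input_ids    # flatten (row, vocab) -> 1D
sparse.view(-1).scatter_reduce_(0, flat_indices, token_weights, reduce="amax")

print(sparse[0, 4])  # tensor(0.9000) -- max of 0.2 and 0.9, not their sum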
    sglang/srt/layers/vocab_parallel_embedding.py
CHANGED

@@ -540,7 +540,10 @@ class ParallelLMHead(VocabParallelEmbedding):

         # We only support pack LMHead if it's not quantized.
         if _is_cpu and _is_cpu_amx_available:
-            if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+            if hasattr(self, "weight") and self.weight.dtype in [
+                torch.bfloat16,
+                torch.float16,
+            ]:
                 self.quant_method = PackWeightMethod(weight_names=["weight"])

         if bias:

    sglang/srt/lora/backend/triton_backend.py
CHANGED

@@ -11,7 +11,6 @@ from sglang.srt.lora.triton_ops import (
 )
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.server_args import ServerArgs


 class TritonLoRABackend(BaseLoRABackend):
    sglang/srt/lora/eviction_policy.py
ADDED

@@ -0,0 +1,139 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Eviction policies for LoRA adapter memory management.
+"""
+
+import logging
+import time
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from typing import Optional, Set
+
+logger = logging.getLogger(__name__)
+
+
+class EvictionPolicy(ABC):
+    """Abstract base class for LoRA adapter eviction policies."""
+
+    @abstractmethod
+    def mark_used(self, uid: Optional[str]) -> None:
+        """Marks an adapter as used."""
+        pass
+
+    @abstractmethod
+    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
+        """Selects an adapter to evict from candidates."""
+        pass
+
+    @abstractmethod
+    def remove(self, uid: Optional[str]) -> None:
+        """Removes an adapter from the policy's tracking."""
+        pass
+
+
+class LRUEvictionPolicy(EvictionPolicy):
+    """LRU eviction policy - evicts the least recently used adapter."""
+
+    def __init__(self):
+        self.access_order = OrderedDict()  # key=uid, value=last_access_time
+        self.total_accesses = 0
+        self.eviction_count = 0
+
+    def mark_used(self, uid: Optional[str]) -> None:
+        if uid is not None:
+            current_time = time.monotonic()
+            # Remove and re-add to move to end (most recent)
+            self.access_order.pop(uid, None)
+            self.access_order[uid] = current_time
+            self.total_accesses += 1
+            logger.debug(f"LoRA {uid} marked as used at {current_time}")
+
+    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
+        """Select the least recently used adapter from candidates."""
+        # Base model (currently None, will be replaced with special UID in future)
+        # always has lowest priority - evict it first if available
+        BASE_MODEL_UID = None  # TODO: Replace with special UID constant
+        if BASE_MODEL_UID in candidates:
+            logger.debug(f"Selected base model for eviction (LRU)")
+            self.eviction_count += 1
+            return BASE_MODEL_UID
         
     | 
| 
      
 73 
     | 
    
         
            +
             
     | 
| 
      
 74 
     | 
    
         
            +
                    # Iterate through access_order (oldest first) to find LRU victim
         
     | 
| 
      
 75 
     | 
    
         
            +
                    for uid in list(self.access_order.keys()):
         
     | 
| 
      
 76 
     | 
    
         
            +
                        if uid in candidates:
         
     | 
| 
      
 77 
     | 
    
         
            +
                            logger.debug(f"Selected LoRA {uid} for eviction (LRU)")
         
     | 
| 
      
 78 
     | 
    
         
            +
                            self.eviction_count += 1
         
     | 
| 
      
 79 
     | 
    
         
            +
                            return uid
         
     | 
| 
      
 80 
     | 
    
         
            +
             
     | 
| 
      
 81 
     | 
    
         
            +
                    # Should never reach here if candidates is non-empty
         
     | 
| 
      
 82 
     | 
    
         
            +
                    assert False, f"Failed to select LRU victim from candidates: {candidates}"
         
     | 
| 
      
 83 
     | 
    
         
            +
             
     | 
| 
      
 84 
     | 
    
         
            +
                def remove(self, uid: Optional[str]) -> None:
         
     | 
| 
      
 85 
     | 
    
         
            +
                    if uid is not None:
         
     | 
| 
      
 86 
     | 
    
         
            +
                        self.access_order.pop(uid, None)
         
     | 
| 
      
 87 
     | 
    
         
            +
                        logger.debug(f"Removed LoRA {uid} from LRU tracking")
         
     | 
| 
      
 88 
     | 
    
         
            +
             
     | 
| 
      
 89 
     | 
    
         
            +
             
     | 
| 
      
 90 
     | 
    
         
            +
            class FIFOEvictionPolicy(EvictionPolicy):
         
     | 
| 
      
 91 
     | 
    
         
            +
                """FIFO eviction policy - for backward compatibility."""
         
     | 
| 
      
 92 
     | 
    
         
            +
             
     | 
| 
      
 93 
     | 
    
         
            +
                def __init__(self):
         
     | 
| 
      
 94 
     | 
    
         
            +
                    self.insertion_order = (
         
     | 
| 
      
 95 
     | 
    
         
            +
                        OrderedDict()
         
     | 
| 
      
 96 
     | 
    
         
            +
                    )  # key=uid, OrderedDict maintains insertion order
         
     | 
| 
      
 97 
     | 
    
         
            +
                    self.eviction_count = 0
         
     | 
| 
      
 98 
     | 
    
         
            +
             
     | 
| 
      
 99 
     | 
    
         
            +
                def mark_used(self, uid: Optional[str]) -> None:
         
     | 
| 
      
 100 
     | 
    
         
            +
                    """For FIFO, we only track insertion order (not access time)."""
         
     | 
| 
      
 101 
     | 
    
         
            +
                    if uid is not None and uid not in self.insertion_order:
         
     | 
| 
      
 102 
     | 
    
         
            +
                        self.insertion_order[uid] = (
         
     | 
| 
      
 103 
     | 
    
         
            +
                            True  # Value unused, OrderedDict tracks insertion order
         
     | 
| 
      
 104 
     | 
    
         
            +
                        )
         
     | 
| 
      
 105 
     | 
    
         
            +
             
     | 
| 
      
 106 
     | 
    
         
            +
                def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
         
     | 
| 
      
 107 
     | 
    
         
            +
                    """Select the first inserted adapter from candidates."""
         
     | 
| 
      
 108 
     | 
    
         
            +
                    # Base model (currently None, will be replaced with special UID in future)
         
     | 
| 
      
 109 
     | 
    
         
            +
                    # always has lowest priority - evict it first if available
         
     | 
| 
      
 110 
     | 
    
         
            +
                    BASE_MODEL_UID = None  # TODO: Replace with special UID constant
         
     | 
| 
      
 111 
     | 
    
         
            +
                    if BASE_MODEL_UID in candidates:
         
     | 
| 
      
 112 
     | 
    
         
            +
                        logger.debug(f"Selected base model for eviction (FIFO)")
         
     | 
| 
      
 113 
     | 
    
         
            +
                        self.eviction_count += 1
         
     | 
| 
      
 114 
     | 
    
         
            +
                        return BASE_MODEL_UID
         
     | 
| 
      
 115 
     | 
    
         
            +
             
     | 
| 
      
 116 
     | 
    
         
            +
                    # Iterate through insertion_order (oldest first) to find FIFO victim
         
     | 
| 
      
 117 
     | 
    
         
            +
                    for uid in list(self.insertion_order.keys()):
         
     | 
| 
      
 118 
     | 
    
         
            +
                        if uid in candidates:
         
     | 
| 
      
 119 
     | 
    
         
            +
                            logger.debug(f"Selected LoRA {uid} for eviction (FIFO)")
         
     | 
| 
      
 120 
     | 
    
         
            +
                            self.eviction_count += 1
         
     | 
| 
      
 121 
     | 
    
         
            +
                            return uid
         
     | 
| 
      
 122 
     | 
    
         
            +
             
     | 
| 
      
 123 
     | 
    
         
            +
                    # Should never reach here if candidates is non-empty
         
     | 
| 
      
 124 
     | 
    
         
            +
                    assert False, f"Failed to select FIFO victim from candidates: {candidates}"
         
     | 
| 
      
 125 
     | 
    
         
            +
             
     | 
| 
      
 126 
     | 
    
         
            +
                def remove(self, uid: Optional[str]) -> None:
         
     | 
| 
      
 127 
     | 
    
         
            +
                    if uid is not None:
         
     | 
| 
      
 128 
     | 
    
         
            +
                        self.insertion_order.pop(uid, None)
         
     | 
| 
      
 129 
     | 
    
         
            +
             
     | 
| 
      
 130 
     | 
    
         
            +
             
     | 
| 
      
 131 
     | 
    
         
            +
            def get_eviction_policy(policy_name: str) -> EvictionPolicy:
         
     | 
| 
      
 132 
     | 
    
         
            +
                """Factory function to create eviction policy instances."""
         
     | 
| 
      
 133 
     | 
    
         
            +
                policies = {
         
     | 
| 
      
 134 
     | 
    
         
            +
                    "fifo": FIFOEvictionPolicy,
         
     | 
| 
      
 135 
     | 
    
         
            +
                    "lru": LRUEvictionPolicy,
         
     | 
| 
      
 136 
     | 
    
         
            +
                }
         
     | 
| 
      
 137 
     | 
    
         
            +
                if policy_name not in policies:
         
     | 
| 
      
 138 
     | 
    
         
            +
                    raise ValueError(f"Unknown eviction policy: {policy_name}")
         
     | 
| 
      
 139 
     | 
    
         
            +
                return policies[policy_name]()
         
     | 
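For context, a minimal sketch of how the added policies behave when used standalone (not taken from the diff; the adapter UIDs are made up for illustration):

    from sglang.srt.lora.eviction_policy import get_eviction_policy

    policy = get_eviction_policy("lru")   # or "fifo"
    policy.mark_used("adapter-a")         # hypothetical UIDs
    policy.mark_used("adapter-b")
    policy.mark_used("adapter-a")         # refreshes recency of adapter-a

    # With both adapters as candidates, LRU picks the stalest one.
    victim = policy.select_victim({"adapter-a", "adapter-b"})
    assert victim == "adapter-b"
    policy.remove(victim)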
    
sglang/srt/lora/lora_manager.py CHANGED

@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict, Iterable, List, Optional 
+from typing import Dict, Iterable, List, Optional

 import torch

@@ -68,6 +68,9 @@ class LoRAManager:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank

+        # Store eviction policy from server args
+        self.eviction_policy = server_args.lora_eviction_policy
+
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)

@@ -131,6 +134,16 @@ class LoRAManager:
             lora_ref.lora_id not in self.loras
         ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."

+        if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
+            return self.create_lora_update_result(
+                success=False,
+                error_message=(
+                    f"Already have {self.num_pinned_loras} pinned adapters, "
+                    f"max allowed is {self.max_loras_per_batch - 1} (reserving 1 slot for dynamic use). "
+                    f"Please unpin some adapters or increase max_loras_per_batch."
+                ),
+            )
+
         try:
             # load configs
             new_adapter = LoRAConfig(lora_ref.lora_path)
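The guard added above caps pinned adapters at max_loras_per_batch - 1 so that at least one buffer slot always remains available for dynamically scheduled adapters. A tiny sketch of the arithmetic (the numbers are invented):

    # Hypothetical values: with max_loras_per_batch = 4, at most 3 adapters
    # may be pinned, reserving 1 slot for dynamic use.
    max_loras_per_batch = 4
    num_pinned_loras = 3
    rejected = num_pinned_loras >= max_loras_per_batch - 1
    assert rejected  # a 4th pin request fails with the error above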
@@ -156,6 +169,15 @@ class LoRAManager:
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """

+        # Check if this LoRA adapter is already loaded
+        if any(
+            lora_ref.lora_name == existing_lora_ref.lora_name
+            for existing_lora_ref in self.lora_refs.values()
+        ):
+            raise ValueError(
+                f"Failed to load LoRA adapter {lora_ref.lora_name} because it is already loaded"
+            )
+
         # Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
         memory_pool = getattr(self, "memory_pool", None)
         incompatible = memory_pool and not memory_pool.can_support(lora_config)
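The new pre-check rejects a load request whose lora_name collides with an already registered adapter. A self-contained sketch of the same scan, using a stub record rather than the real LoRARef class:

    from dataclasses import dataclass

    @dataclass
    class StubLoRARef:  # stand-in for the real LoRARef record
        lora_name: str

    lora_refs = {"uid-1": StubLoRARef("my-adapter")}  # already registered
    new_name = "my-adapter"
    if any(new_name == ref.lora_name for ref in lora_refs.values()):
        raise ValueError(
            f"Failed to load LoRA adapter {new_name} because it is already loaded"
        )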
@@ -411,6 +433,7 @@ class LoRAManager:
             max_lora_rank=self.max_lora_rank,
             target_modules=self.target_modules,
             base_model=self.base_model,
+            eviction_policy=self.eviction_policy,
         )

     def set_lora_module(self, module_name, module):

@@ -418,10 +441,6 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module

-    def should_skip_lora_for_vision_model(self, module_name):
-        # TODO: support different vision models
-        return module_name.find("vision_model.model") != -1
-
     def init_lora_modules(self):
         # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
         self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [

@@ -439,10 +458,6 @@ class LoRAManager:
             ) and not self.base_model.should_apply_lora(module_name):
                 continue

-            # Skip vision model
-            if self.should_skip_lora_for_vision_model(module_name):
-                continue
-
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
sglang/srt/lora/lora_registry.py CHANGED

@@ -18,8 +18,8 @@ from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4

-from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.utils import ConcurrentCounter
+from sglang.srt.utils.aio_rwlock import RWLock


 @dataclass(frozen=True)
sglang/srt/lora/mem_pool.py CHANGED

@@ -4,6 +4,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 import torch

 from sglang.srt.distributed import divide
+from sglang.srt.lora.eviction_policy import get_eviction_policy
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig

@@ -54,6 +55,7 @@ class LoRAMemoryPool:
         max_lora_rank: int,
         target_modules: Set[str],
         base_model: torch.nn.Module,
+        eviction_policy: str,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers

@@ -64,6 +66,9 @@ class LoRAMemoryPool:
         self.max_lora_rank: int = max_lora_rank
         self.target_modules: Set[str] = target_modules

+        # Initialize eviction policy
+        self.eviction_policy = get_eviction_policy(eviction_policy)
+
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
         #   (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
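Taken together with the LoRAManager change above, the policy name travels as a plain string from server args into the pool, which resolves it into a concrete policy object at construction time. A short sketch of that wiring (the value is assumed to be one of the two supported names):

    from sglang.srt.lora.eviction_policy import get_eviction_policy

    # server_args.lora_eviction_policy -> LoRAManager.eviction_policy
    #   -> LoRAMemoryPool(eviction_policy=...) -> a concrete policy object
    policy = get_eviction_policy("fifo")
    print(type(policy).__name__)  # FIFOEvictionPolicy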
@@ -189,31 +194,50 @@ class LoRAMemoryPool:
         lora_refs: Dict[str, LoRARef],
     ):
         def get_available_buffer_slot():
+            # 1. Prioritize empty slots
             for buffer_id in range(self.max_loras_per_batch):
-                # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
                     return buffer_id

+            # 2. Memory pool is full, need to evict using policy
+            candidates = set()
+
             for buffer_id in range(self.max_loras_per_batch):
                 uid = self.buffer_id_to_uid[buffer_id]

-                # [old inline eviction logic; illegible in the extracted source]
-                    self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
-                    return buffer_id
+                # Skip if this adapter is needed by current batch
+                # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
+                if uid in cur_uids:
+                    continue
+
+                # Skip if this adapter is pinned (base model cannot be pinned, so can be evicted)
+                if uid is not None:
+                    lora_ref = lora_refs.get(uid)
+                    if lora_ref and lora_ref.pinned:
+                        continue
+                candidates.add(uid)

-            [removed lines illegible in the extracted source]
+            if not candidates:
+                raise ValueError(
+                    "No available buffer slots found. Please ensure the number of active (pinned) loras is less than max_loras_per_batch."
+                )
+
+            # Select victim using eviction policy
+            victim_uid = self.eviction_policy.select_victim(candidates)
+
+            # Evict the selected victim
+            victim_buffer_id = self.uid_to_buffer_id[victim_uid]
+            self.uid_to_buffer_id.pop(victim_uid)
+            self.eviction_policy.remove(victim_uid)
+            self.buffer_id_to_uid[victim_buffer_id] = EMPTY_SLOT
+            logger.debug(
+                f"Evicting LoRA {victim_uid} from buffer slot {victim_buffer_id}."
             )
+            return victim_buffer_id
+
+        # Mark all adapters in current batch as used (for LRU tracking)
+        for uid in cur_uids:
+            self.eviction_policy.mark_used(uid)

         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
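To make the new slot-selection flow concrete, here is a self-contained simulation of the same steps (empty-slot scan, candidate filtering, policy-driven eviction) against stub pool state; the names mirror the diff, but the data is invented:

    from sglang.srt.lora.eviction_policy import get_eviction_policy

    EMPTY_SLOT = object()  # stand-in sentinel; the real pool defines its own
    buffer_id_to_uid = ["adapter-a", "adapter-b"]       # pool is full
    uid_to_buffer_id = {"adapter-a": 0, "adapter-b": 1}
    cur_uids = {"adapter-b", "adapter-c"}               # adapter-c needs a slot
    pinned = {"adapter-a": False, "adapter-b": False}   # nothing is pinned

    policy = get_eviction_policy("lru")
    policy.mark_used("adapter-a")
    policy.mark_used("adapter-b")

    # No empty slots, so build the eviction candidate set the same way the
    # diff does: skip adapters the current batch needs and pinned adapters.
    candidates = {
        uid for uid in buffer_id_to_uid if uid not in cur_uids and not pinned[uid]
    }
    victim = policy.select_victim(candidates)   # -> "adapter-a"
    slot = uid_to_buffer_id.pop(victim)
    policy.remove(victim)
    buffer_id_to_uid[slot] = EMPTY_SLOT         # slot 0 is now free for adapter-c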
[file header missing in the extracted source]

@@ -9,7 +9,7 @@ from sglang.srt.utils import cached_triton_kernel


 @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
-@triton.jit
+@triton.jit(do_not_specialize=["num_segs"])
 def _chunked_lora_expand_kernel(
     # Pointers to matrices
     x,

[file header missing in the extracted source]

@@ -6,8 +6,10 @@ from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import cached_triton_kernel


-@cached_triton_kernel( 
-    [removed line illegible in the extracted source]
+@cached_triton_kernel(
+    lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"])
+)
+@triton.jit(do_not_specialize=["num_segs"])
 def _chunked_lora_shrink_kernel(
     # Pointers to matrices
     x,
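Both kernel hunks swap a bare @triton.jit for @triton.jit(do_not_specialize=["num_segs"]). do_not_specialize is a standard triton.jit option that stops the compiler from generating value-specialized variants of the kernel for the listed integer arguments; since num_segs varies from batch to batch, this avoids repeated recompilation. A toy kernel showing the decorator shape (the body is illustrative only, not from the diff):

    import triton
    import triton.language as tl

    @triton.jit(do_not_specialize=["num_segs"])
    def _toy_kernel(out_ptr, num_segs, BLOCK: tl.constexpr):
        # num_segs stays a plain runtime integer: Triton will not compile
        # value-specialized variants (e.g. for num_segs == 1) as it changes.
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        tl.store(out_ptr + offs, num_segs)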