sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -15,14 +15,13 @@
 from __future__ import annotations
 
 import logging
-import
-from typing import TYPE_CHECKING, Optional
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
@@ -33,16 +32,14 @@ from sglang.srt.managers.io_struct import (
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardBatchOutput,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -59,7 +56,145 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class TpModelWorker:
+class BaseTpWorker(ABC):
+    @abstractmethod
+    def forward_batch_generation(self, forward_batch: ForwardBatch):
+        pass
+
+    @property
+    @abstractmethod
+    def model_runner(self) -> ModelRunner:
+        pass
+
+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.model_runner.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.model_runner.is_hybrid is not None
+
+    def get_tokens_per_layer_info(self):
+        return (
+            self.model_runner.full_max_total_num_tokens,
+            self.model_runner.swa_max_total_num_tokens,
+        )
+
+    def get_pad_input_ids_func(self):
+        return getattr(self.model_runner.model, "pad_input_ids", None)
+
+    def get_tp_group(self):
+        return self.model_runner.tp_group
+
+    def get_attention_tp_group(self):
+        return self.model_runner.attention_tp_group
+
+    def get_attention_tp_cpu_group(self):
+        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
+
+    def get_memory_pool(self):
+        return (
+            self.model_runner.req_to_token_pool,
+            self.model_runner.token_to_kv_pool_allocator,
+        )
+
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.model_runner.update_weights_from_disk(
+            recv_req.model_path, recv_req.load_format
+        )
+        return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.init_weights_update_group(
+            recv_req.master_address,
+            recv_req.master_port,
+            recv_req.rank_offset,
+            recv_req.world_size,
+            recv_req.group_name,
+            recv_req.backend,
+        )
+        return success, message
+
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.destroy_weights_update_group(
+            recv_req.group_name,
+        )
+        return success, message
+
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        success, message = (
+            self.model_runner.init_weights_send_group_for_remote_instance(
+                recv_req.master_address,
+                recv_req.ports,
+                recv_req.group_rank,
+                recv_req.world_size,
+                recv_req.group_name,
+                recv_req.backend,
+            )
+        )
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        success, message = self.model_runner.send_weights_to_remote_instance(
+            recv_req.master_address,
+            recv_req.ports,
+            recv_req.group_name,
+        )
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.model_runner.update_weights_from_distributed(
+            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
+        )
+        return success, message
+
+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+
+        monkey_patch_torch_reductions()
+        success, message = self.model_runner.update_weights_from_tensor(
+            named_tensors=MultiprocessingSerializer.deserialize(
+                recv_req.serialized_named_tensors[self.tp_rank]
+            ),
+            load_format=recv_req.load_format,
+        )
+        return success, message
+
+    def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
+        """Update weights from IPC for checkpoint-engine integration."""
+        success, message = self.model_runner.update_weights_from_ipc(recv_req)
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
+
+    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
+        return result
+
+    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
+        return result
+
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
+
+    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        logits_output, _ = self.model_runner.forward(forward_batch)
+        embeddings = logits_output.embeddings
+        return embeddings
+
+
+class TpModelWorker(BaseTpWorker):
     """A tensor parallel model worker."""
 
     def __init__(
@@ -97,7 +232,7 @@ class TpModelWorker:
             is_draft_model=is_draft_worker,
         )
 
-        self.model_runner = ModelRunner(
+        self._model_runner = ModelRunner(
            model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             gpu_id=gpu_id,
@@ -173,11 +308,13 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
-
-        self.worker = self
-
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.hicache_layer_transfer_counter = None
 
+    @property
+    def model_runner(self) -> ModelRunner:
+        return self._model_runner
+
     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter
 
@@ -195,54 +332,29 @@ class TpModelWorker:
             self.max_req_input_len,
             self.random_seed,
             self.device,
-            global_server_args_dict,
             self.model_runner.req_to_token_pool.size,
             self.model_runner.req_to_token_pool.max_context_len,
             self.model_runner.token_to_kv_pool.size,
         )
 
-    @property
-    def sliding_window_size(self) -> Optional[int]:
-        return self.model_runner.sliding_window_size
-
-    @property
-    def is_hybrid(self) -> bool:
-        return self.model_runner.is_hybrid is not None
-
-    def get_tokens_per_layer_info(self):
-        return (
-            self.model_runner.full_max_total_num_tokens,
-            self.model_runner.swa_max_total_num_tokens,
-        )
-
-    def get_pad_input_ids_func(self):
-        return getattr(self.model_runner.model, "pad_input_ids", None)
-
-    def get_tp_group(self):
-        return self.model_runner.tp_group
-
-    def get_attention_tp_group(self):
-        return self.model_runner.attention_tp_group
-
-    def get_attention_tp_cpu_group(self):
-        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
-
-    def get_memory_pool(self):
-        return (
-            self.model_runner.req_to_token_pool,
-            self.model_runner.token_to_kv_pool_allocator,
-        )
-
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-
+        forward_batch: Optional[ForwardBatch] = None,
         is_verify: bool = False,
-    ) -> ForwardBatchOutput:
-        # update the consumer index of hicache to the running batch
-        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+        skip_attn_backend_init=False,
+    ) -> GenerationBatchResult:
+        # FIXME(lsyin): maybe remove skip_attn_backend_init in forward_batch_generation,
+        # which requires preparing replay to always be in this function
 
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        if model_worker_batch is not None:
+            # update the consumer index of hicache to the running batch
+            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
+            forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        else:
+            # FIXME(lsyin): unify the interface of forward_batch
+            assert forward_batch is not None
 
         pp_proxy_tensors = None
         if not self.pp_group.is_first_rank:
@@ -254,123 +366,62 @@
 
         if self.pp_group.is_last_rank:
             logits_output, can_run_cuda_graph = self.model_runner.forward(
-                forward_batch,
+                forward_batch,
+                pp_proxy_tensors=pp_proxy_tensors,
+                skip_attn_backend_init=skip_attn_backend_init,
             )
-
-            launch_done.set()
-
-            skip_sample = is_verify or model_worker_batch.is_prefill_only
-            next_token_ids = None
-
-            if not skip_sample:
-                next_token_ids = self.model_runner.sample(logits_output, forward_batch)
-            elif model_worker_batch.return_logprob and not is_verify:
-                # NOTE: Compute logprobs without full sampling
-                self.model_runner.compute_logprobs_only(
-                    logits_output, model_worker_batch
-                )
-
-            return ForwardBatchOutput(
+            batch_result = GenerationBatchResult(
                 logits_output=logits_output,
-                next_token_ids=next_token_ids,
                 can_run_cuda_graph=can_run_cuda_graph,
             )
+
+            if is_verify:
+                # Skip sampling and return logits for target forward
+                return batch_result
+
+            if (
+                self.enable_overlap
+                and model_worker_batch.sampling_info.grammars is not None
+            ):
+
+                def sample_batch_func():
+                    batch_result.next_token_ids = self.model_runner.sample(
+                        logits_output, forward_batch
+                    )
+                    return batch_result
+
+                batch_result.delay_sample_func = sample_batch_func
+                return batch_result
+
+            if model_worker_batch.is_prefill_only:
+                # For prefill-only requests, create dummy token IDs on CPU
+                # The size should match the batch size (number of sequences), not total tokens
+                batch_result.next_token_ids = torch.zeros(
+                    len(model_worker_batch.seq_lens),
+                    dtype=torch.long,
+                    device=model_worker_batch.input_ids.device,
+                )
+                if (
+                    model_worker_batch.return_logprob
+                    and logits_output.next_token_logits is not None
+                ):
+                    # NOTE: Compute logprobs without full sampling
+                    self.model_runner.compute_logprobs_only(
+                        logits_output, model_worker_batch
+                    )
            else:
+                batch_result.next_token_ids = self.model_runner.sample(
+                    logits_output, forward_batch
+                )
+
+            return batch_result
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
+                skip_attn_backend_init=skip_attn_backend_init,
             )
-            return ForwardBatchOutput(
-
+            return GenerationBatchResult(
+                pp_hidden_states_proxy_tensors=pp_proxy_tensors,
                 can_run_cuda_graph=can_run_cuda_graph,
             )
-
-    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output, _ = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings
-        return embeddings
-
-    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-        success, message = self.model_runner.update_weights_from_disk(
-            recv_req.model_path, recv_req.load_format
-        )
-        return success, message
-
-    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
-        success, message = self.model_runner.init_weights_update_group(
-            recv_req.master_address,
-            recv_req.master_port,
-            recv_req.rank_offset,
-            recv_req.world_size,
-            recv_req.group_name,
-            recv_req.backend,
-        )
-        return success, message
-
-    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
-        success, message = self.model_runner.destroy_weights_update_group(
-            recv_req.group_name,
-        )
-        return success, message
-
-    def init_weights_send_group_for_remote_instance(
-        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
-    ):
-        success, message = (
-            self.model_runner.init_weights_send_group_for_remote_instance(
-                recv_req.master_address,
-                recv_req.ports,
-                recv_req.group_rank,
-                recv_req.world_size,
-                recv_req.group_name,
-                recv_req.backend,
-            )
-        )
-        return success, message
-
-    def send_weights_to_remote_instance(
-        self, recv_req: SendWeightsToRemoteInstanceReqInput
-    ):
-        success, message = self.model_runner.send_weights_to_remote_instance(
-            recv_req.master_address,
-            recv_req.ports,
-            recv_req.group_name,
-        )
-        return success, message
-
-    def update_weights_from_distributed(
-        self, recv_req: UpdateWeightsFromDistributedReqInput
-    ):
-        success, message = self.model_runner.update_weights_from_distributed(
-            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
-        )
-        return success, message
-
-    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
-
-        monkey_patch_torch_reductions()
-        success, message = self.model_runner.update_weights_from_tensor(
-            named_tensors=MultiprocessingSerializer.deserialize(
-                recv_req.serialized_named_tensors[self.tp_rank]
-            ),
-            load_format=recv_req.load_format,
-        )
-        return success, message
-
-    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
-        parameter = self.model_runner.get_weights_by_name(
-            recv_req.name, recv_req.truncate_size
-        )
-        return parameter
-
-    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
-        return result
-
-    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
-        return result
-
-    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
-        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
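The tp_worker.py change above folds the weight-update, LoRA, and memory-pool helpers into a new BaseTpWorker ABC, and forward_batch_generation now returns a GenerationBatchResult instead of ForwardBatchOutput. One behavioral detail worth noting: with overlap scheduling enabled and grammar-constrained sampling state attached to the batch, the worker no longer samples inline; it hands back a delay_sample_func closure for the scheduler to invoke later. Below is a minimal standalone sketch of that deferred-sampling pattern (toy names, not sglang's API):

# Toy illustration of the delay_sample_func pattern; argmax stands in for
# the real sampler, and ToyBatchResult for GenerationBatchResult.
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyBatchResult:
    logits: List[List[float]]                  # stand-in for LogitsProcessorOutput
    next_token_ids: Optional[List[int]] = None
    delay_sample_func: Optional[Callable[[], "ToyBatchResult"]] = None


def argmax(row: List[float]) -> int:
    return max(range(len(row)), key=row.__getitem__)


def forward(logits: List[List[float]], defer_sampling: bool) -> ToyBatchResult:
    result = ToyBatchResult(logits=logits)
    if defer_sampling:
        # Sampling depends on state that is not ready yet (e.g. grammar masks),
        # so package it as a closure for the scheduler to run later.
        def sample_later() -> ToyBatchResult:
            result.next_token_ids = [argmax(row) for row in result.logits]
            return result

        result.delay_sample_func = sample_later
        return result
    result.next_token_ids = [argmax(row) for row in logits]
    return result


res = forward([[0.1, 0.9], [0.7, 0.3]], defer_sampling=True)
assert res.next_token_ids is None
res.delay_sample_func()   # the scheduler resolves the deferred sample
assert res.next_token_ids == [1, 0]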
sglang/srt/managers/utils.py
CHANGED
@@ -1,19 +1,95 @@
 from __future__ import annotations
 
+import dataclasses
 import logging
-import
-
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.overlap_utils import FutureIndices
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 
 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import GenerationBatchResult
+    from sglang.srt.speculative.eagle_info import EagleDraftInput
+
 
 logger = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class GenerationBatchResult:
+    logits_output: Optional[LogitsProcessorOutput] = None
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    can_run_cuda_graph: bool = False
+
+    # For output processing
+    extend_input_len_per_req: Optional[List[int]] = None
+    extend_logprob_start_len_per_req: Optional[List[int]] = None
+
+    # For overlap scheduling
+    copy_done: Optional[torch.cuda.Event] = None
+    delay_sample_func: Optional[callable] = None
+    future_indices: Optional[FutureIndices] = None
+
+    # FIXME(lsyin): maybe move to a better place?
+    # sync path: forward stream -> output processor
+    accept_lens: Optional[torch.Tensor] = None
+    allocate_lens: Optional[torch.Tensor] = None
+
+    # relay path: forward stream -> next step forward
+    next_draft_input: Optional[EagleDraftInput] = None
+
+    def copy_to_cpu(self, return_logprob: bool = False):
+        """Copy tensors to CPU in overlap scheduling.
+        Only the tensors which are needed for processing results are copied,
+        e.g., next_token_ids, logits outputs
+        """
+        if return_logprob:
+            if self.logits_output.next_token_logits is not None:
+                self.logits_output.next_token_logits = (
+                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
+                )
+            if self.logits_output.input_token_logprobs is not None:
+                self.logits_output.input_token_logprobs = (
+                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                )
+        if self.logits_output.hidden_states is not None:
+            self.logits_output.hidden_states = self.logits_output.hidden_states.to(
+                "cpu", non_blocking=True
+            )
+        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
+
+        if self.accept_lens is not None:
+            self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
+
+        if self.allocate_lens is not None:
+            self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
+
+        self.copy_done.record()
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): refactor PP and avoid using dict
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
+
+
 def validate_input_length(
     req: Req, max_req_input_len: int, allow_auto_truncate: bool
 ) -> Optional[str]:
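GenerationBatchResult now lives in managers/utils.py and carries the overlap-scheduling plumbing (copy_done, delay_sample_func, future_indices) alongside the forward outputs. Its copy_to_cpu launches non-blocking device-to-host copies and records a CUDA event, so the output processor can wait on just those copies instead of synchronizing the whole forward stream. A standalone sketch of that event pattern (assumed semantics, not sglang code):

# Producer/consumer sketch of the copy_done event handoff.
import torch

if torch.cuda.is_available():
    copy_done = torch.cuda.Event()
    next_token_ids = torch.randint(0, 32000, (8,), device="cuda")

    # Producer (forward stream): enqueue the async D2H copy, then record the
    # event. non_blocking=True only fully overlaps with pinned host memory,
    # but the event still marks when the copy has finished either way.
    next_token_ids_cpu = next_token_ids.to("cpu", non_blocking=True)
    copy_done.record()

    # Consumer (output processor): block only until the recorded work is done.
    copy_done.synchronize()
    print(next_token_ids_cpu.tolist())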
sglang/srt/mem_cache/allocator.py
CHANGED
@@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self.full_to_swa_index_mapping[free_index] = 0
 
     def backup_state(self):
-
+        return [
+            self.full_attn_allocator.backup_state(),
+            self.swa_attn_allocator.backup_state(),
+        ]
 
     def restore_state(self, state):
-
+        assert len(state) == 2
+        self.full_attn_allocator.restore_state(state[0])
+        self.swa_attn_allocator.restore_state(state[1])
 
     def clear(self):
         self.swa_attn_allocator.clear()
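backup_state and restore_state on the hybrid SWA allocator now snapshot the full-attention and sliding-window sub-allocators together, so a caller can roll both back atomically after a speculative allocation. A toy illustration of that invariant (assumed semantics, not the real allocator):

# Minimal free-list allocators; the hybrid wrapper mirrors the diff above.
class FreeListAllocator:
    def __init__(self, n: int):
        self.free = list(range(n))

    def alloc(self, k: int):
        out, self.free = self.free[:k], self.free[k:]
        return out

    def backup_state(self):
        return list(self.free)

    def restore_state(self, state):
        self.free = list(state)


class HybridAllocator:
    def __init__(self):
        self.full_attn_allocator = FreeListAllocator(4)
        self.swa_attn_allocator = FreeListAllocator(4)

    def backup_state(self):
        return [
            self.full_attn_allocator.backup_state(),
            self.swa_attn_allocator.backup_state(),
        ]

    def restore_state(self, state):
        assert len(state) == 2
        self.full_attn_allocator.restore_state(state[0])
        self.swa_attn_allocator.restore_state(state[1])


a = HybridAllocator()
snap = a.backup_state()
a.full_attn_allocator.alloc(2)
a.swa_attn_allocator.alloc(3)
a.restore_state(snap)   # both pools roll back together
assert len(a.full_attn_allocator.free) == 4 and len(a.swa_attn_allocator.free) == 4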
sglang/srt/mem_cache/allocator_ascend.py
CHANGED
@@ -92,7 +92,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         )
 
         if num_new_pages_item < 200:
-            import sgl_kernel_npu
+            import sgl_kernel_npu  # noqa: F401
 
             torch.ops.npu.alloc_extend(
                 prefix_lens,
@@ -119,7 +119,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         assert len(torch.unique(out_indices)) == len(out_indices)
 
         self.free_pages = self.free_pages[num_new_pages_item:]
-        return out_indices
+        return out_indices.int()
 
     def alloc_decode(
         self,
sglang/srt/mem_cache/base_prefix_cache.py
CHANGED
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple
 
 import torch
 
@@ -40,7 +40,7 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def cache_finished_req(self, req: Req, **kwargs):
+    def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs):
         pass
 
     @abstractmethod
sglang/srt/mem_cache/chunk_cache.py
CHANGED
@@ -27,6 +27,12 @@ class ChunkCache(BasePrefixCache):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
+        if self.token_to_kv_pool_allocator:
+            self.device = self.token_to_kv_pool_allocator.device
+        else:
+            self.device = torch.device("cpu")
+
+        self.protected_size_ = 0
 
         # NOTE (csy): this is to determine if a cache has prefix matching feature.
         # Chunk cache always return True to indicate no prefix matching.
@@ -45,7 +51,7 @@ class ChunkCache(BasePrefixCache):
             last_host_node=None,
         )
 
-    def cache_finished_req(self, req: Req,
+    def cache_finished_req(self, req: Req, is_insert: bool = True):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx,
             # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
@@ -53,14 +59,16 @@ class ChunkCache(BasePrefixCache):
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
+        self.protected_size_ -= len(req.prefix_indices)
 
     def cache_unfinished_req(self, req: Req, chunked=False):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(req.fill_ids)
         ]
+        self.protected_size_ += len(kv_indices) - len(req.prefix_indices)
 
         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
-        req.prefix_indices = kv_indices
+        req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
 
     def evict(self, num_tokens: int):
         pass
@@ -71,6 +79,9 @@ class ChunkCache(BasePrefixCache):
     def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         return 0
 
+    def protected_size(self):
+        return self.protected_size_
+
     def pretty_print(self):
         return ""
 
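The new protected_size_ counter in ChunkCache tracks KV slots still pinned by in-flight chunked requests: cache_unfinished_req protects the newly cached span (len(kv_indices) - len(req.prefix_indices)), and cache_finished_req releases the request's whole protected prefix. A quick bookkeeping check of that invariant (toy numbers, assumed semantics):

# Each prefill chunk protects only its new tokens; finishing releases all.
protected = 0
prefix_len = 0                          # len(req.prefix_indices)

for chunk_len in (128, 128, 64):        # three prefill chunks of one request
    fill_len = prefix_len + chunk_len   # len(req.fill_ids) after this chunk
    protected += fill_len - prefix_len  # cache_unfinished_req
    prefix_len = fill_len

protected -= prefix_len                 # cache_finished_req
assert protected == 0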