sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_runtime_checker_mixin.py (new file):

@@ -0,0 +1,217 @@
+from __future__ import annotations
+
+import time
+from typing import TYPE_CHECKING
+
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
+from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
+
+class SchedulerRuntimeCheckerMixin:
+
+    def _check_hybrid_memory(self: Scheduler):
+        (
+            full_num_used,
+            swa_num_used,
+            _,
+            _,
+            full_available_size,
+            full_evictable_size,
+            swa_available_size,
+            swa_evictable_size,
+        ) = self._get_swa_token_info()
+        memory_leak = full_num_used != 0 or swa_num_used != 0
+        token_msg = (
+            f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
+            f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
+        )
+        return memory_leak, token_msg
+
+    def _check_mamba_memory(self: Scheduler):
+        (
+            full_num_used,
+            mamba_num_used,
+            _,
+            _,
+            full_available_size,
+            full_evictable_size,
+            mamba_available_size,
+            mamba_evictable_size,
+        ) = self._get_mamba_token_info()
+        memory_leak = (
+            full_num_used != self.tree_cache.full_protected_size()
+            or mamba_num_used != self.tree_cache.mamba_protected_size()
+        )
+        token_msg = (
+            f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
+            f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
+        )
+        return memory_leak, token_msg
+
+    def _check_radix_cache_memory(self: Scheduler):
+        _, _, available_size, evictable_size = self._get_token_info()
+        protected_size = self.tree_cache.protected_size()
+        memory_leak = (available_size + evictable_size) != (
+            # self.max_total_num_tokens
+            # if not self.enable_hierarchical_cache
+            # else self.max_total_num_tokens - protected_size
+            self.max_total_num_tokens
+            - protected_size
+        )
+        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
+        return memory_leak, token_msg
+
+    def _check_runtime_mem_leak(self: Scheduler):
+        current_batch: ScheduleBatch = self.last_batch
+
+        if current_batch is None:
+            return
+
+        _, _, available_size, evictable_size = self._get_token_info()
+        protected_size = self.tree_cache.protected_size()
+
+        extend_size = 0
+        for i, req in enumerate(current_batch.reqs):
+            seq_len = len(req.origin_input_ids) + len(req.output_ids)
+            fill_len = len(req.fill_ids) if req.fill_ids is not None else 0
+            prefix_len = (
+                len(req.prefix_indices) if req.prefix_indices is not None else 0
+            )
+
+            if current_batch.forward_mode.is_decode():
+                if req.finished():
+                    unreleased_len = 1
+                else:
+                    unreleased_len = seq_len - prefix_len
+            else:
+                unreleased_len = fill_len - prefix_len
+
+            extend_size += unreleased_len
+
+        if (
+            current_batch.forward_mode.is_extend()
+            and self.running_batch is not None
+            and not self.running_batch.is_empty()
+            and self.running_batch.forward_mode.is_decode()
+        ):
+            for i, req in enumerate(self.running_batch.reqs):
+                seq_len = len(req.origin_input_ids) + len(req.output_ids)
+                prefix_len = (
+                    len(req.prefix_indices) if req.prefix_indices is not None else 0
+                )
+
+                if req.finished():
+                    unreleased_len = 0
+                else:
+                    unreleased_len = seq_len - prefix_len - 1
+
+                extend_size += unreleased_len
+
+        total_tokens = available_size + evictable_size + protected_size + extend_size
+
+        assert (
+            total_tokens == self.max_total_num_tokens
+        ), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
+
+    def _check_req_pool(self: Scheduler):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            req_total_size = (
+                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
+            )
+        else:
+            req_total_size = self.req_to_token_pool.size
+
+        if len(self.req_to_token_pool.free_slots) != req_total_size:
+            msg = (
+                "req_to_token_pool memory leak detected!"
+                f"available_size={len(self.req_to_token_pool.free_slots)}, "
+                f"total_size={self.req_to_token_pool.size}\n"
+            )
+            raise ValueError(msg)
+
+    def check_memory(self: Scheduler):
+        if self.is_hybrid:
+            memory_leak, token_msg = self._check_hybrid_memory()
+        elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
+            memory_leak, token_msg = self._check_mamba_memory()
+        else:
+            memory_leak, token_msg = self._check_radix_cache_memory()
+
+        if memory_leak:
+            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
+            raise ValueError(msg)
+
+        self._check_req_pool()
+
+        if (
+            self.enable_metrics
+            and self.current_scheduler_metrics_enabled()
+            and time.perf_counter() > self.metrics_collector.last_log_time + 30
+        ):
+            # During idle time, also collect metrics every 30 seconds.
+            if self.is_hybrid:
+                (
+                    full_num_used,
+                    swa_num_used,
+                    full_token_usage,
+                    swa_token_usage,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = self._get_swa_token_info()
+                num_used = max(full_num_used, swa_num_used)
+                token_usage = max(full_token_usage, swa_token_usage)
+            elif self.is_hybrid_gdn:
+                (
+                    num_used,
+                    _,
+                    token_usage,
+                    _,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = self._get_mamba_token_info()
+            else:
+                num_used, token_usage, _, _ = self._get_token_info()
+            num_running_reqs = len(self.running_batch.reqs)
+            self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = round(token_usage, 2)
+            self.stats.gen_throughput = 0
+            self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )
+            self.metrics_collector.log_stats(self.stats)
+        self._publish_kv_events()
+
+    def check_tree_cache(self: Scheduler):
+        if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
+            self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
+        ):
+            self.tree_cache.sanity_check()
+
+    def self_check_during_idle(self: Scheduler):
+        self.check_memory()
+        self.check_tree_cache()
+        self.new_token_ratio = self.init_new_token_ratio
+        self.maybe_sleep_on_idle()
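
The core of the new `SchedulerRuntimeCheckerMixin` is a conservation invariant over the KV-cache token pool: every slot is either free (`available_size`), cached but reclaimable (`evictable_size`), pinned by the radix cache (`protected_size`), or held by an in-flight request (`extend_size`), and the buckets must sum back to `max_total_num_tokens`. Below is a minimal sketch of that bookkeeping outside sglang, using a hypothetical toy pool (all names are illustrative, not sglang APIs):

```python
from dataclasses import dataclass


@dataclass
class ToyTokenPool:
    """Hypothetical stand-in for sglang's KV token accounting (illustrative only)."""

    max_total_num_tokens: int  # pool capacity, fixed at startup
    available: int             # free slots
    evictable: int             # cached prefixes that could be evicted
    protected: int             # prefixes pinned by in-flight prefix matches
    inflight_extend: int       # tokens held by currently running requests

    def check_leak(self) -> None:
        # Mirrors _check_runtime_mem_leak: each slot sits in exactly one
        # bucket, so the buckets must sum back to the pool capacity.
        total = self.available + self.evictable + self.protected + self.inflight_extend
        assert total == self.max_total_num_tokens, (
            f"Mem Leak Detected! {total=} vs {self.max_total_num_tokens=}"
        )


pool = ToyTokenPool(
    max_total_num_tokens=1024,
    available=600,
    evictable=300,
    protected=100,
    inflight_extend=24,
)
pool.check_leak()  # 600 + 300 + 100 + 24 == 1024, so no leak

pool.inflight_extend -= 8  # simulate a request that never returned 8 slots
# pool.check_leak()        # would now raise: 1016 != 1024
```

On an idle scheduler no request holds tokens, which is why `check_memory` can apply the stricter form of the check: for the plain radix cache, `available + evictable` must equal the capacity minus the protected prefix size.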
sglang/srt/managers/scheduler_update_weights_mixin.py:

@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import Tuple
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 
@@ -19,10 +21,15 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromIPCReqInput,
+    UpdateWeightsFromIPCReqOutput,
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)
 
 
@@ -75,11 +82,25 @@ class SchedulerUpdateWeightsMixin:
         torch.distributed.barrier(group=self.tp_cpu_group)
         return UpdateWeightsFromTensorReqOutput(success, message)
 
+    def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
+        """Update the online model parameter from IPC for checkpoint-engine integration."""
+        success, message = self.tp_worker.update_weights_from_ipc(recv_req)
+        if success:
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        torch.distributed.barrier(group=self.tp_cpu_group)
+        return UpdateWeightsFromIPCReqOutput(success, message)
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)
 
-    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+    def release_memory_occupation(
+        self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
+    ):
         tags = recv_req.tags
 
         if tags is None or len(tags) == 0:
@@ -94,14 +115,16 @@ class SchedulerUpdateWeightsMixin:
 
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.stashed_model_static_state = _export_static_state(
-                self.tp_worker.worker.model_runner.model
+                self.tp_worker.model_runner.model
             )
             torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
 
         return ReleaseMemoryOccupationReqOutput()
 
-    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+    def resume_memory_occupation(
+        self: Scheduler, recv_req: ResumeMemoryOccupationReqInput
+    ):
         tags = recv_req.tags
 
         if tags is None or len(tags) == 0:
@@ -114,7 +137,7 @@ class SchedulerUpdateWeightsMixin:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
             _import_static_state(
-                self.tp_worker.worker.model_runner.model,
+                self.tp_worker.model_runner.model,
                 self.stashed_model_static_state,
             )
             del self.stashed_model_static_state
@@ -124,24 +147,20 @@ class SchedulerUpdateWeightsMixin:
 
         return ResumeMemoryOccupationReqOutput()
 
-    def save_remote_model(self, params):
+    def save_remote_model(self: Scheduler, params):
         url = params["url"]
 
-        worker = self.tp_worker.worker
-        worker.model_runner.save_remote_model(url)
+        self.tp_worker.model_runner.save_remote_model(url)
 
         if self.draft_worker is not None:
            draft_url = params.get("draft_url", None)
            assert (
                draft_url is not None
            ), "draft_url must be provided when draft model is enabled"
-            draft_worker = self.draft_worker.worker
-            draft_worker.model_runner.save_remote_model(draft_url)
-
-    def save_sharded_model(self, params):
-        worker = self.tp_worker.worker
+            self.draft_worker.model_runner.save_remote_model(draft_url)
 
-        worker.model_runner.save_sharded_model(
+    def save_sharded_model(self: Scheduler, params):
+        self.tp_worker.model_runner.save_sharded_model(
             path=params["path"],
             pattern=params["pattern"],
             max_size=params["max_size"],
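
A pattern worth noting in this and the surrounding mixin rewrites is the `self: Scheduler` annotation paired with an `if TYPE_CHECKING:` import: the mixin methods can reference `Scheduler` attributes (`tp_worker`, `draft_worker`, ...) in a way type checkers understand, without importing `Scheduler` at runtime and creating an import cycle. A minimal single-file sketch of the idea, using hypothetical `CheckerMixin`/`Host` names (in sglang the host class lives in a separate module, hence the deferred import):

```python
from __future__ import annotations  # makes `self: Host` a lazy forward reference


class CheckerMixin:
    # Annotating `self` with the concrete host class (as the sglang mixins do
    # with `self: Scheduler`) lets type checkers resolve attributes that the
    # mixin itself never defines; at runtime the annotation is inert.
    def over_budget(self: Host) -> bool:
        return self.used > self.budget


class Host(CheckerMixin):
    def __init__(self, budget: int) -> None:
        self.budget = budget
        self.used = 0


host = Host(budget=10)
host.used = 7
assert not host.over_budget()  # 7 <= 10
```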
sglang/srt/managers/tokenizer_communicator_mixin.py:

@@ -3,7 +3,6 @@ from __future__ import annotations
 import asyncio
 import copy
 import logging
-import os
 import time
 import uuid
 from collections import deque
@@ -46,7 +45,6 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     LoRAUpdateOutput,
-    MultiTokenizerWrapper,
     OpenSessionReqInput,
     ProfileReq,
     ProfileReqOutput,
@@ -65,6 +63,8 @@ from sglang.srt.managers.io_struct import (
     UnloadLoRAAdapterReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromIPCReqInput,
+    UpdateWeightsFromIPCReqOutput,
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
@@ -83,8 +83,6 @@ logger = logging.getLogger(__name__)
 class _Communicator(Generic[T]):
     """Note: The communicator now only run up to 1 in-flight request at any time."""
 
-    enable_multi_tokenizer = False
-
     def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
         self._sender = sender
         self._fan_out = fan_out
@@ -104,8 +102,6 @@ class _Communicator(Generic[T]):
         assert self._result_values is None
 
         if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
             self._sender.send_pyobj(obj)
 
         self._result_event = asyncio.Event()
@@ -126,8 +122,6 @@ class _Communicator(Generic[T]):
         self._result_event = asyncio.Event()
 
         if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
             self._sender.send_pyobj(obj)
 
         await self._result_event.wait()
@@ -146,6 +140,13 @@ class _Communicator(Generic[T]):
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
 
+    @staticmethod
+    def merge_results(results):
+        all_success = all([r.success for r in results])
+        all_message = [r.message for r in results]
+        all_message = " | ".join(all_message)
+        return all_success, all_message
+
 
 class TokenizerCommunicatorMixin:
     """Mixin class for TokenizerManager to handle communication with the scheduler."""
@@ -170,6 +171,9 @@ class TokenizerCommunicatorMixin:
         self.update_weights_from_tensor_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.update_weights_from_ipc_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -236,6 +240,10 @@ class TokenizerCommunicatorMixin:
                 UpdateWeightsFromTensorReqOutput,
                 self.update_weights_from_tensor_communicator.handle_recv,
             ),
+            (
+                UpdateWeightsFromIPCReqOutput,
+                self.update_weights_from_ipc_communicator.handle_recv,
+            ),
             (
                 GetWeightsByNameReqOutput,
                 self.get_weights_by_name_communicator.handle_recv,
@@ -306,6 +314,7 @@ class TokenizerCommunicatorMixin:
         with_stack: Optional[bool] = None,
         record_shapes: Optional[bool] = None,
         profile_by_stage: bool = False,
+        merge_profiles: bool = False,
     ):
         self.auto_create_handle_loop()
         env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
@@ -320,6 +329,7 @@ class TokenizerCommunicatorMixin:
             record_shapes=record_shapes,
             profile_by_stage=profile_by_stage,
             profile_id=str(time.time()),
+            merge_profiles=merge_profiles,
         )
         return await self._execute_profile(req)
 
@@ -356,10 +366,11 @@ class TokenizerCommunicatorMixin:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
+
+        results = await self.init_weights_update_group_communicator(obj)
+        return _Communicator.merge_results(results)
 
     async def destroy_weights_update_group(
         self,
@@ -368,10 +379,11 @@ class TokenizerCommunicatorMixin:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for destroy parameter update group"
-        result = (await self.destroy_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for destroy parameter update group"
+
+        results = await self.destroy_weights_update_group_communicator(obj)
+        return _Communicator.merge_results(results)
 
     async def update_weights_from_distributed(
         self: TokenizerManager,
@@ -389,8 +401,8 @@ class TokenizerCommunicatorMixin:
         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
+            results = await self.update_weights_from_distributed_communicator(obj)
+            return _Communicator.merge_results(results)
 
     async def init_weights_send_group_for_remote_instance(
         self,
@@ -439,6 +451,28 @@ class TokenizerCommunicatorMixin:
         result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message
 
+    async def update_weights_from_ipc(
+        self,
+        obj: UpdateWeightsFromIPCReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        """Update weights via IPC for checkpoint-engine integration."""
+        self.auto_create_handle_loop()
+        try:
+            # For now, we only support single data parallel instance
+            assert (
+                self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+            ), "dp_size must be 1 or dp attention must be enabled for update weights from IPC"
+            logger.info("Starting IPC weight update")
+            # This means that weight sync cannot run while requests are in progress.
+            async with self.model_update_lock.writer_lock:
+                result = (await self.update_weights_from_ipc_communicator(obj))[0]
+                return result.success, result.message
+        except Exception as e:
+            error_msg = f"IPC weight update failed: {str(e)}"
+            logger.error(error_msg)
+            return False, error_msg
+
     async def load_lora_adapter(
         self: TokenizerManager,
         obj: LoadLoRAAdapterReqInput,
@@ -606,8 +640,6 @@ class TokenizerCommunicatorMixin:
         elif obj.session_id in self.session_futures:
             return None
 
-        if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
 
         self.session_futures[obj.session_id] = asyncio.Future()
@@ -627,43 +659,27 @@ class TokenizerCommunicatorMixin:
         if self.log_requests:
             if self.log_requests_level == 0:
                 max_length = 1 << 30
-                skip_names = set(
-                    [
-                        "text",
-                        "input_ids",
-                        "input_embeds",
-                        "image_data",
-                        "audio_data",
-                        "lora_path",
-                        "sampling_params",
-                    ]
-                )
-                out_skip_names = set(
-                    [
-                        "text",
-                        "output_ids",
-                        "embedding",
-                    ]
-                )
+                skip_names = {
+                    "text",
+                    "input_ids",
+                    "input_embeds",
+                    "image_data",
+                    "audio_data",
+                    "lora_path",
+                    "sampling_params",
+                }
+                out_skip_names = {"text", "output_ids", "embedding"}
             elif self.log_requests_level == 1:
                 max_length = 1 << 30
-                skip_names = set(
-                    [
-                        "text",
-                        "input_ids",
-                        "input_embeds",
-                        "image_data",
-                        "audio_data",
-                        "lora_path",
-                    ]
-                )
-                out_skip_names = set(
-                    [
-                        "text",
-                        "output_ids",
-                        "embedding",
-                    ]
-                )
+                skip_names = {
+                    "text",
+                    "input_ids",
+                    "input_embeds",
+                    "image_data",
+                    "audio_data",
+                    "lora_path",
+                }
+                out_skip_names = {"text", "output_ids", "embedding"}
             elif self.log_requests_level == 2:
                 max_length = 2048
             elif self.log_requests_level == 3:
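
The new `_Communicator.merge_results` staticmethod is what lets the rewritten weight-update entry points above return a single `(success, message)` pair even when a request fans out to `dp_size` scheduler replies: success is the AND across ranks, and the per-rank messages are joined with " | ". A standalone sketch of the same reduction, using a hypothetical `Reply` type in place of the `*ReqOutput` dataclasses:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Reply:
    """Hypothetical stand-in for sglang's *ReqOutput reply objects."""

    success: bool
    message: str


def merge_results(results: List[Reply]) -> Tuple[bool, str]:
    # Mirrors _Communicator.merge_results: the overall update succeeds only
    # if every data-parallel rank succeeded; messages are joined with " | ".
    all_success = all(r.success for r in results)
    all_message = " | ".join(r.message for r in results)
    return all_success, all_message


replies = [Reply(True, "rank0 ok"), Reply(False, "rank1 OOM"), Reply(True, "rank2 ok")]
ok, msg = merge_results(replies)
assert ok is False
print(msg)  # rank0 ok | rank1 OOM | rank2 ok
```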
|