sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_one_batch.py +47 -28
 - sglang/bench_one_batch_server.py +41 -25
 - sglang/bench_serving.py +330 -156
 - sglang/check_env.py +1 -1
 - sglang/compile_deep_gemm.py +6 -2
 - sglang/global_config.py +1 -25
 - sglang/lang/api.py +6 -0
 - sglang/lang/interpreter.py +1 -0
 - sglang/lang/ir.py +13 -0
 - sglang/launch_server.py +8 -15
 - sglang/profiler.py +18 -1
 - sglang/srt/_custom_ops.py +1 -1
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
 - sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
 - sglang/srt/compilation/backend.py +437 -0
 - sglang/srt/compilation/compilation_config.py +20 -0
 - sglang/srt/compilation/compilation_counter.py +47 -0
 - sglang/srt/compilation/compile.py +210 -0
 - sglang/srt/compilation/compiler_interface.py +503 -0
 - sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
 - sglang/srt/compilation/fix_functionalization.py +134 -0
 - sglang/srt/compilation/fx_utils.py +83 -0
 - sglang/srt/compilation/inductor_pass.py +140 -0
 - sglang/srt/compilation/pass_manager.py +66 -0
 - sglang/srt/compilation/piecewise_context_manager.py +40 -0
 - sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
 - sglang/srt/configs/__init__.py +4 -0
 - sglang/srt/configs/deepseek_ocr.py +262 -0
 - sglang/srt/configs/deepseekvl2.py +194 -96
 - sglang/srt/configs/dots_vlm.py +2 -7
 - sglang/srt/configs/falcon_h1.py +13 -64
 - sglang/srt/configs/load_config.py +25 -2
 - sglang/srt/configs/mamba_utils.py +117 -0
 - sglang/srt/configs/model_config.py +134 -23
 - sglang/srt/configs/modelopt_config.py +30 -0
 - sglang/srt/configs/nemotron_h.py +286 -0
 - sglang/srt/configs/olmo3.py +105 -0
 - sglang/srt/configs/points_v15_chat.py +29 -0
 - sglang/srt/configs/qwen3_next.py +11 -47
 - sglang/srt/configs/qwen3_omni.py +613 -0
 - sglang/srt/configs/qwen3_vl.py +0 -10
 - sglang/srt/connector/remote_instance.py +1 -1
 - sglang/srt/constrained/base_grammar_backend.py +5 -1
 - sglang/srt/constrained/llguidance_backend.py +5 -0
 - sglang/srt/constrained/outlines_backend.py +1 -1
 - sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
 - sglang/srt/constrained/utils.py +12 -0
 - sglang/srt/constrained/xgrammar_backend.py +20 -11
 - sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
 - sglang/srt/disaggregation/base/conn.py +17 -4
 - sglang/srt/disaggregation/common/conn.py +4 -2
 - sglang/srt/disaggregation/decode.py +123 -31
 - sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
 - sglang/srt/disaggregation/fake/conn.py +11 -3
 - sglang/srt/disaggregation/mooncake/conn.py +157 -19
 - sglang/srt/disaggregation/nixl/conn.py +69 -24
 - sglang/srt/disaggregation/prefill.py +96 -270
 - sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
 - sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
 - sglang/srt/distributed/device_communicators/pynccl.py +24 -12
 - sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
 - sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
 - sglang/srt/distributed/naive_distributed.py +5 -4
 - sglang/srt/distributed/parallel_state.py +70 -19
 - sglang/srt/elastic_ep/elastic_ep.py +74 -0
 - sglang/srt/entrypoints/context.py +3 -2
 - sglang/srt/entrypoints/engine.py +66 -66
 - sglang/srt/entrypoints/grpc_server.py +431 -234
 - sglang/srt/entrypoints/harmony_utils.py +2 -2
 - sglang/srt/entrypoints/http_server.py +120 -8
 - sglang/srt/entrypoints/http_server_engine.py +1 -7
 - sglang/srt/entrypoints/openai/protocol.py +225 -37
 - sglang/srt/entrypoints/openai/serving_base.py +49 -2
 - sglang/srt/entrypoints/openai/serving_chat.py +29 -74
 - sglang/srt/entrypoints/openai/serving_classify.py +204 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +15 -1
 - sglang/srt/entrypoints/openai/serving_responses.py +5 -2
 - sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
 - sglang/srt/environ.py +42 -4
 - sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
 - sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
 - sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
 - sglang/srt/eplb/expert_distribution.py +3 -4
 - sglang/srt/eplb/expert_location_dispatch.py +2 -2
 - sglang/srt/eplb/expert_location_updater.py +2 -2
 - sglang/srt/function_call/base_format_detector.py +17 -18
 - sglang/srt/function_call/function_call_parser.py +18 -14
 - sglang/srt/function_call/glm4_moe_detector.py +1 -5
 - sglang/srt/function_call/gpt_oss_detector.py +1 -1
 - sglang/srt/function_call/json_array_parser.py +0 -2
 - sglang/srt/function_call/utils.py +2 -2
 - sglang/srt/grpc/compile_proto.py +3 -3
 - sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
 - sglang/srt/grpc/health_servicer.py +189 -0
 - sglang/srt/grpc/scheduler_launcher.py +181 -0
 - sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
 - sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
 - sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
 - sglang/srt/layers/activation.py +4 -1
 - sglang/srt/layers/attention/aiter_backend.py +3 -3
 - sglang/srt/layers/attention/ascend_backend.py +17 -1
 - sglang/srt/layers/attention/attention_registry.py +43 -23
 - sglang/srt/layers/attention/base_attn_backend.py +20 -1
 - sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
 - sglang/srt/layers/attention/fla/chunk.py +0 -1
 - sglang/srt/layers/attention/fla/chunk_o.py +1 -1
 - sglang/srt/layers/attention/fla/index.py +0 -2
 - sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
 - sglang/srt/layers/attention/fla/utils.py +0 -3
 - sglang/srt/layers/attention/fla/wy_fast.py +0 -2
 - sglang/srt/layers/attention/flashattention_backend.py +12 -8
 - sglang/srt/layers/attention/flashinfer_backend.py +248 -21
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
 - sglang/srt/layers/attention/flashmla_backend.py +2 -2
 - sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
 - sglang/srt/layers/attention/intel_amx_backend.py +1 -1
 - sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
 - sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
 - sglang/srt/layers/attention/mamba/mamba.py +189 -241
 - sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
 - sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
 - sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
 - sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
 - sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
 - sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
 - sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
 - sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
 - sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
 - sglang/srt/layers/attention/nsa/utils.py +0 -1
 - sglang/srt/layers/attention/nsa_backend.py +404 -90
 - sglang/srt/layers/attention/triton_backend.py +208 -34
 - sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
 - sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
 - sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
 - sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
 - sglang/srt/layers/attention/utils.py +11 -7
 - sglang/srt/layers/attention/vision.py +3 -3
 - sglang/srt/layers/attention/xpu_backend.py +1028 -0
 - sglang/srt/layers/communicator.py +11 -7
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
 - sglang/srt/layers/dp_attention.py +17 -0
 - sglang/srt/layers/layernorm.py +45 -15
 - sglang/srt/layers/linear.py +9 -1
 - sglang/srt/layers/logits_processor.py +147 -17
 - sglang/srt/layers/modelopt_utils.py +11 -0
 - sglang/srt/layers/moe/cutlass_moe.py +0 -2
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
 - sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
 - sglang/srt/layers/moe/ep_moe/layer.py +119 -397
 - sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton.py +3 -1
 - sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
 - sglang/srt/layers/moe/router.py +51 -15
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
 - sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
 - sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
 - sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
 - sglang/srt/layers/moe/topk.py +3 -2
 - sglang/srt/layers/moe/utils.py +17 -1
 - sglang/srt/layers/quantization/__init__.py +2 -53
 - sglang/srt/layers/quantization/awq.py +183 -6
 - sglang/srt/layers/quantization/awq_triton.py +29 -0
 - sglang/srt/layers/quantization/base_config.py +20 -1
 - sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
 - sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
 - sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
 - sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
 - sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
 - sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
 - sglang/srt/layers/quantization/fp8.py +84 -18
 - sglang/srt/layers/quantization/fp8_kernel.py +55 -10
 - sglang/srt/layers/quantization/fp8_utils.py +42 -14
 - sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
 - sglang/srt/layers/quantization/gptq.py +0 -1
 - sglang/srt/layers/quantization/int8_kernel.py +18 -2
 - sglang/srt/layers/quantization/marlin_utils.py +12 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +125 -100
 - sglang/srt/layers/quantization/mxfp4.py +5 -30
 - sglang/srt/layers/quantization/petit.py +1 -1
 - sglang/srt/layers/quantization/quark/quark.py +3 -1
 - sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
 - sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
 - sglang/srt/layers/quantization/unquant.py +1 -4
 - sglang/srt/layers/quantization/utils.py +0 -1
 - sglang/srt/layers/quantization/w4afp8.py +51 -20
 - sglang/srt/layers/quantization/w8a8_int8.py +30 -24
 - sglang/srt/layers/radix_attention.py +59 -9
 - sglang/srt/layers/rotary_embedding.py +673 -16
 - sglang/srt/layers/sampler.py +36 -16
 - sglang/srt/layers/sparse_pooler.py +98 -0
 - sglang/srt/layers/utils.py +0 -1
 - sglang/srt/layers/vocab_parallel_embedding.py +4 -1
 - sglang/srt/lora/backend/triton_backend.py +0 -1
 - sglang/srt/lora/eviction_policy.py +139 -0
 - sglang/srt/lora/lora_manager.py +24 -9
 - sglang/srt/lora/lora_registry.py +1 -1
 - sglang/srt/lora/mem_pool.py +40 -16
 - sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
 - sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
 - sglang/srt/managers/cache_controller.py +48 -17
 - sglang/srt/managers/data_parallel_controller.py +146 -42
 - sglang/srt/managers/detokenizer_manager.py +40 -13
 - sglang/srt/managers/io_struct.py +66 -16
 - sglang/srt/managers/mm_utils.py +20 -18
 - sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
 - sglang/srt/managers/overlap_utils.py +96 -19
 - sglang/srt/managers/schedule_batch.py +241 -511
 - sglang/srt/managers/schedule_policy.py +15 -2
 - sglang/srt/managers/scheduler.py +399 -499
 - sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
 - sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
 - sglang/srt/managers/scheduler_pp_mixin.py +341 -0
 - sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
 - sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
 - sglang/srt/managers/tokenizer_manager.py +378 -90
 - sglang/srt/managers/tp_worker.py +212 -161
 - sglang/srt/managers/utils.py +78 -2
 - sglang/srt/mem_cache/allocator.py +7 -2
 - sglang/srt/mem_cache/allocator_ascend.py +2 -2
 - sglang/srt/mem_cache/base_prefix_cache.py +2 -2
 - sglang/srt/mem_cache/chunk_cache.py +13 -2
 - sglang/srt/mem_cache/common.py +480 -0
 - sglang/srt/mem_cache/evict_policy.py +16 -1
 - sglang/srt/mem_cache/hicache_storage.py +4 -1
 - sglang/srt/mem_cache/hiradix_cache.py +16 -3
 - sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
 - sglang/srt/mem_cache/memory_pool.py +435 -219
 - sglang/srt/mem_cache/memory_pool_host.py +0 -1
 - sglang/srt/mem_cache/multimodal_cache.py +0 -1
 - sglang/srt/mem_cache/radix_cache.py +53 -19
 - sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
 - sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
 - sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
 - sglang/srt/mem_cache/storage/backend_factory.py +2 -2
 - sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
 - sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
 - sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
 - sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
 - sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
 - sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
 - sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
 - sglang/srt/mem_cache/swa_radix_cache.py +92 -26
 - sglang/srt/metrics/collector.py +31 -0
 - sglang/srt/metrics/func_timer.py +1 -1
 - sglang/srt/model_executor/cuda_graph_runner.py +43 -5
 - sglang/srt/model_executor/forward_batch_info.py +28 -23
 - sglang/srt/model_executor/model_runner.py +379 -139
 - sglang/srt/model_executor/npu_graph_runner.py +2 -3
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
 - sglang/srt/model_loader/__init__.py +1 -1
 - sglang/srt/model_loader/loader.py +424 -27
 - sglang/srt/model_loader/utils.py +0 -1
 - sglang/srt/model_loader/weight_utils.py +47 -28
 - sglang/srt/models/apertus.py +2 -3
 - sglang/srt/models/arcee.py +2 -2
 - sglang/srt/models/bailing_moe.py +13 -52
 - sglang/srt/models/bailing_moe_nextn.py +3 -4
 - sglang/srt/models/bert.py +1 -1
 - sglang/srt/models/deepseek_nextn.py +19 -3
 - sglang/srt/models/deepseek_ocr.py +1516 -0
 - sglang/srt/models/deepseek_v2.py +273 -98
 - sglang/srt/models/dots_ocr.py +0 -2
 - sglang/srt/models/dots_vlm.py +0 -1
 - sglang/srt/models/dots_vlm_vit.py +1 -1
 - sglang/srt/models/falcon_h1.py +13 -19
 - sglang/srt/models/gemma3_mm.py +16 -0
 - sglang/srt/models/gemma3n_mm.py +1 -2
 - sglang/srt/models/glm4_moe.py +14 -37
 - sglang/srt/models/glm4_moe_nextn.py +2 -2
 - sglang/srt/models/glm4v.py +2 -1
 - sglang/srt/models/glm4v_moe.py +5 -5
 - sglang/srt/models/gpt_oss.py +5 -5
 - sglang/srt/models/grok.py +10 -23
 - sglang/srt/models/hunyuan.py +2 -7
 - sglang/srt/models/interns1.py +0 -1
 - sglang/srt/models/kimi_vl.py +1 -7
 - sglang/srt/models/kimi_vl_moonvit.py +3 -1
 - sglang/srt/models/llama.py +2 -2
 - sglang/srt/models/llama_eagle3.py +1 -1
 - sglang/srt/models/longcat_flash.py +5 -22
 - sglang/srt/models/longcat_flash_nextn.py +3 -14
 - sglang/srt/models/mimo.py +2 -13
 - sglang/srt/models/mimo_mtp.py +1 -2
 - sglang/srt/models/minicpmo.py +7 -5
 - sglang/srt/models/mixtral.py +1 -4
 - sglang/srt/models/mllama.py +1 -1
 - sglang/srt/models/mllama4.py +13 -3
 - sglang/srt/models/nemotron_h.py +511 -0
 - sglang/srt/models/olmo2.py +31 -4
 - sglang/srt/models/opt.py +5 -5
 - sglang/srt/models/phi.py +1 -1
 - sglang/srt/models/phi4mm.py +1 -1
 - sglang/srt/models/phimoe.py +0 -1
 - sglang/srt/models/pixtral.py +0 -3
 - sglang/srt/models/points_v15_chat.py +186 -0
 - sglang/srt/models/qwen.py +0 -1
 - sglang/srt/models/qwen2_5_vl.py +3 -3
 - sglang/srt/models/qwen2_audio.py +2 -15
 - sglang/srt/models/qwen2_moe.py +15 -12
 - sglang/srt/models/qwen2_vl.py +5 -2
 - sglang/srt/models/qwen3_moe.py +19 -35
 - sglang/srt/models/qwen3_next.py +7 -12
 - sglang/srt/models/qwen3_next_mtp.py +3 -4
 - sglang/srt/models/qwen3_omni_moe.py +661 -0
 - sglang/srt/models/qwen3_vl.py +37 -33
 - sglang/srt/models/qwen3_vl_moe.py +57 -185
 - sglang/srt/models/roberta.py +55 -3
 - sglang/srt/models/sarashina2_vision.py +0 -1
 - sglang/srt/models/step3_vl.py +3 -5
 - sglang/srt/models/utils.py +11 -1
 - sglang/srt/multimodal/processors/base_processor.py +6 -2
 - sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
 - sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
 - sglang/srt/multimodal/processors/dots_vlm.py +0 -1
 - sglang/srt/multimodal/processors/glm4v.py +1 -5
 - sglang/srt/multimodal/processors/internvl.py +0 -2
 - sglang/srt/multimodal/processors/janus_pro.py +0 -1
 - sglang/srt/multimodal/processors/mllama4.py +0 -8
 - sglang/srt/multimodal/processors/phi4mm.py +0 -1
 - sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
 - sglang/srt/multimodal/processors/qwen_vl.py +75 -16
 - sglang/srt/multimodal/processors/step3_vl.py +1 -1
 - sglang/srt/parser/conversation.py +41 -0
 - sglang/srt/parser/reasoning_parser.py +0 -1
 - sglang/srt/sampling/custom_logit_processor.py +77 -2
 - sglang/srt/sampling/sampling_batch_info.py +17 -22
 - sglang/srt/sampling/sampling_params.py +70 -2
 - sglang/srt/server_args.py +577 -73
 - sglang/srt/server_args_config_parser.py +1 -1
 - sglang/srt/single_batch_overlap.py +38 -28
 - sglang/srt/speculative/base_spec_worker.py +34 -0
 - sglang/srt/speculative/draft_utils.py +226 -0
 - sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
 - sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
 - sglang/srt/speculative/eagle_info.py +57 -18
 - sglang/srt/speculative/eagle_info_v2.py +458 -0
 - sglang/srt/speculative/eagle_utils.py +138 -0
 - sglang/srt/speculative/eagle_worker.py +83 -280
 - sglang/srt/speculative/eagle_worker_v2.py +702 -0
 - sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
 - sglang/srt/speculative/ngram_worker.py +12 -11
 - sglang/srt/speculative/spec_info.py +2 -0
 - sglang/srt/speculative/spec_utils.py +38 -3
 - sglang/srt/speculative/standalone_worker.py +4 -14
 - sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
 - sglang/srt/two_batch_overlap.py +28 -14
 - sglang/srt/utils/__init__.py +1 -1
 - sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
 - sglang/srt/utils/common.py +192 -47
 - sglang/srt/utils/hf_transformers_utils.py +40 -17
 - sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
 - sglang/srt/{offloader.py → utils/offloader.py} +4 -4
 - sglang/srt/utils/profile_merger.py +199 -0
 - sglang/test/attention/test_flashattn_backend.py +1 -1
 - sglang/test/attention/test_flashattn_mla_backend.py +0 -1
 - sglang/test/attention/test_prefix_chunk_info.py +0 -2
 - sglang/test/attention/test_trtllm_mla_backend.py +221 -53
 - sglang/test/few_shot_gsm8k_engine.py +2 -4
 - sglang/test/kit_matched_stop.py +157 -0
 - sglang/test/longbench_v2/__init__.py +1 -0
 - sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
 - sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
 - sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
 - sglang/test/run_eval.py +41 -0
 - sglang/test/runners.py +2 -0
 - sglang/test/send_one.py +42 -7
 - sglang/test/simple_eval_common.py +3 -0
 - sglang/test/simple_eval_gpqa.py +0 -1
 - sglang/test/simple_eval_humaneval.py +0 -3
 - sglang/test/simple_eval_longbench_v2.py +344 -0
 - sglang/test/test_block_fp8.py +1 -2
 - sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
 - sglang/test/test_cutlass_moe.py +1 -2
 - sglang/test/test_cutlass_w4a8_moe.py +10 -20
 - sglang/test/test_deterministic.py +232 -99
 - sglang/test/test_deterministic_utils.py +73 -0
 - sglang/test/test_disaggregation_utils.py +81 -0
 - sglang/test/test_marlin_moe.py +0 -1
 - sglang/test/test_utils.py +85 -20
 - sglang/version.py +1 -1
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
 - sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
 - sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
 - sglang/srt/speculative/build_eagle_tree.py +0 -427
 - sglang/test/test_block_fp8_ep.py +0 -358
 - /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
 - /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
 - /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/managers/io_struct.py CHANGED

```diff
@@ -36,10 +36,10 @@ else:
     Image = Any
 
 
-# Parameters for a session
 @dataclass
 class BaseReq(ABC):
     rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
+    http_worker_ipc: Optional[str] = field(default=None, kw_only=True)
 
     def regenerate_rid(self):
         """Generate a new request ID and return it."""
@@ -53,6 +53,7 @@ class BaseReq(ABC):
 @dataclass
 class BaseBatchReq(ABC):
     rids: Optional[List[str]] = field(default=None, kw_only=True)
+    http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)
 
     def regenerate_rids(self):
         """Generate new request IDs and return them."""
@@ -60,9 +61,11 @@ class BaseBatchReq(ABC):
         return self.rids
 
 
+# Parameters for a session
 @dataclass
 class SessionParams:
     id: Optional[str] = None
+    rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
     drop_previous_output: Optional[bool] = None
@@ -169,6 +172,9 @@ class GenerateReqInput(BaseReq):
     # (Internal) Whether to return bytes for image generation
     return_bytes: bool = False
 
+    # Whether to return entropy
+    return_entropy: bool = False
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -567,6 +573,7 @@ class GenerateReqInput(BaseReq):
             no_logs=self.no_logs,
             custom_labels=self.custom_labels,
             return_bytes=self.return_bytes,
+            return_entropy=self.return_entropy,
         )
 
 
@@ -632,6 +639,9 @@ class TokenizedGenerateReqInput(BaseReq):
     # (Internal) Whether to return bytes for image generation
     return_bytes: bool = False
 
+    # Whether to return entropy
+    return_entropy: bool = False
+
 
 @dataclass
 class BatchTokenizedGenerateReqInput(BaseBatchReq):
@@ -815,6 +825,7 @@ class BatchTokenIDOutput(BaseBatchReq):
     completion_tokens: List[int]
     cached_tokens: List[int]
     spec_verify_ct: List[int]
+    spec_accepted_tokens: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -829,6 +840,7 @@ class BatchTokenIDOutput(BaseBatchReq):
     input_token_ids_logprobs_idx: List[List]
     output_token_ids_logprobs_val: List[List]
     output_token_ids_logprobs_idx: List[List]
+    output_token_entropy_val: List[float]
 
     # Hidden states
     output_hidden_states: List[List[float]]
@@ -839,6 +851,9 @@ class BatchTokenIDOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
+
 
 @dataclass
 class BatchMultimodalDecodeReq(BaseBatchReq):
@@ -860,11 +875,16 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
     completion_tokens: List[int]
     cached_tokens: List[int]
 
-    #
+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
-    return_bytes: bool
+    return_bytes: List[bool]
+
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
 
 
 @dataclass
@@ -881,6 +901,7 @@ class BatchStrOutput(BaseBatchReq):
     completion_tokens: List[int]
     cached_tokens: List[int]
     spec_verify_ct: List[int]
+    spec_accepted_tokens: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -895,13 +916,20 @@ class BatchStrOutput(BaseBatchReq):
     input_token_ids_logprobs_idx: List[List]
     output_token_ids_logprobs_val: List[List]
     output_token_ids_logprobs_idx: List[List]
+    output_token_entropy_val: List[float]
 
     # Hidden states
     output_hidden_states: List[List[float]]
 
+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
+
 
 @dataclass
 class BatchMultimodalOutput(BaseBatchReq):
@@ -933,7 +961,7 @@ class BatchEmbeddingOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # The output embedding
-    embeddings: List[List[float]]
+    embeddings: Union[List[List[float]], List[Dict[int, float]]]
     # Token counts
     prompt_tokens: List[int]
     cached_tokens: List[int]
@@ -978,6 +1006,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
     torch_empty_cache: bool = False
     # Whether to keep the scheduler paused after weight update
     keep_pause: bool = False
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_step: int = 0
 
 
 @dataclass
@@ -1050,6 +1080,24 @@ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
     backend: str = "nccl"
 
 
+# Now UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput
+# are only used by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine)
+@dataclass
+class UpdateWeightsFromIPCReqInput(BaseReq):
+    # ZMQ socket paths for each device UUID
+    zmq_handles: Dict[str, str]
+    # Whether to flush cache after weight update
+    flush_cache: bool = True
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightsFromIPCReqOutput(BaseReq):
+    success: bool
+    message: str
+
+
 @dataclass
 class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
     success: bool
@@ -1206,6 +1254,8 @@ class ProfileReqInput(BaseReq):
     profile_by_stage: bool = False
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
+    # Merge profiles from all ranks into a single trace
+    merge_profiles: bool = False
 
 
 class ProfileReqType(Enum):
@@ -1224,6 +1274,8 @@ class ProfileReq(BaseReq):
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
     profile_id: Optional[str] = None
+    # Merge profiles from all ranks into a single trace
+    merge_profiles: bool = False
 
 
 @dataclass
@@ -1375,18 +1427,6 @@ class LoRAUpdateOutput(BaseReq):
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
 
 
-@dataclass
-class MultiTokenizerRegisterReq(BaseBatchReq):
-    ipc_name: Optional[str] = None
-
-
-@dataclass
-class MultiTokenizerWrapper:
-    # FIXME(lsyin): remove this
-    worker_id: int
-    obj: Optional[Any] = None
-
-
 class BlockReqType(Enum):
     BLOCK = 1
     UNBLOCK = 2
@@ -1415,6 +1455,16 @@ class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]
 
 
+@dataclass
+class LazyDumpTensorsReqInput(BaseReq):
+    pass
+
+
+@dataclass
+class LazyDumpTensorsReqOutput(BaseReq):
+    success: bool
+
+
 def _check_all_req_types():
     """A helper function to check all request types are defined in this file."""
     import inspect
```
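The new `io_struct` fields above are all additive and keyword-friendly. Below is a minimal sketch of how a caller might exercise them; the class and field names come straight from the diff, while the argument values (and the pre-existing `text` field of `GenerateReqInput`) are illustrative assumptions, not part of this change:

```python
# Sketch only: field names mirror the 0.5.4 diff above; all values are made up.
from sglang.srt.managers.io_struct import (
    GenerateReqInput,
    SessionParams,
    UpdateWeightsFromIPCReqInput,
)

# New `return_entropy` flag: request per-token entropy alongside the output.
req = GenerateReqInput(text="The capital of France is", return_entropy=True)

# `SessionParams` now carries an optional request id (`rid`).
session = SessionParams(id="sess-0", rid="req-0", offset=0)

# Checkpoint Engine weight updates over IPC: one ZMQ socket path per device
# UUID (the UUID and socket path below are placeholders).
update = UpdateWeightsFromIPCReqInput(
    zmq_handles={"GPU-0": "ipc:///tmp/ce_gpu0.sock"},
    flush_cache=True,           # flush the KV cache after the update
    weight_version="step_100",  # optional version tag recorded with the weights
)
```

The matching `UpdateWeightsFromIPCReqOutput` (a success flag plus a message) is the scheduler's reply type for that request.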
    
        sglang/srt/managers/mm_utils.py
    CHANGED
    
    | 
         @@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import ( 
     | 
|
| 
       16 
16 
     | 
    
         
             
                Modality,
         
     | 
| 
       17 
17 
     | 
    
         
             
                MultimodalDataItem,
         
     | 
| 
       18 
18 
     | 
    
         
             
                MultimodalInputs,
         
     | 
| 
       19 
     | 
    
         
            -
                global_server_args_dict,
         
     | 
| 
       20 
19 
     | 
    
         
             
            )
         
     | 
| 
       21 
20 
     | 
    
         
             
            from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
         
     | 
| 
       22 
21 
     | 
    
         
             
            from sglang.srt.model_executor.forward_batch_info import ForwardBatch
         
     | 
| 
      
 22 
     | 
    
         
            +
            from sglang.srt.server_args import get_global_server_args
         
     | 
| 
       23 
23 
     | 
    
         
             
            from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
         
     | 
| 
       24 
24 
     | 
    
         
             
            from sglang.utils import logger
         
     | 
| 
       25 
25 
     | 
    
         | 
| 
         @@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa 
     | 
|
| 
       280 
280 
     | 
    
         
             
                        input_ids_tensor[input_ids_tensor == token_id] = pad_value
         
     | 
| 
       281 
281 
     | 
    
         | 
| 
       282 
282 
     | 
    
         
             
                    ret_input_ids = input_ids_tensor.tolist()
         
     | 
| 
       283 
     | 
    
         
            -
             
     | 
| 
       284 
283 
     | 
    
         
             
                    return ret_input_ids
         
     | 
| 
       285 
284 
     | 
    
         | 
| 
       286 
285 
     | 
    
         | 
| 
         @@ -428,7 +427,7 @@ def _adjust_embedding_length( 
     | 
|
| 
       428 
427 
     | 
    
         
             
                        f"tokens from multimodal embeddings."
         
     | 
| 
       429 
428 
     | 
    
         
             
                    )
         
     | 
| 
       430 
429 
     | 
    
         
             
                    if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
         
     | 
| 
       431 
     | 
    
         
            -
                        chunked_prefill_size =  
     | 
| 
      
 430 
     | 
    
         
            +
                        chunked_prefill_size = get_global_server_args().chunked_prefill_size
         
     | 
| 
       432 
431 
     | 
    
         
             
                        if chunked_prefill_size != -1:
         
     | 
| 
       433 
432 
     | 
    
         
             
                            logger.warning(
         
     | 
| 
       434 
433 
     | 
    
         
             
                                "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
         
     | 
| 
         @@ -507,7 +506,7 @@ def embed_mm_inputs( 
     | 
|
| 
       507 
506 
     | 
    
         
             
                    Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
         
     | 
| 
       508 
507 
     | 
    
         
             
                ] = None,
         
     | 
| 
       509 
508 
     | 
    
         
             
                placeholder_tokens: dict[Modality, List[int]] = None,
         
     | 
| 
       510 
     | 
    
         
            -
                use_deepstack: bool =  
     | 
| 
      
 509 
     | 
    
         
            +
                use_deepstack: Dict[Modality, bool] = {},
         
     | 
| 
       511 
510 
     | 
    
         
             
            ) -> Optional[torch.Tensor]:
         
     | 
| 
       512 
511 
     | 
    
         
             
                """
         
     | 
| 
       513 
512 
     | 
    
         
             
                Embed multimodal inputs and integrate them with text token embeddings.
         
     | 
| 
         @@ -533,7 +532,9 @@ def embed_mm_inputs( 
     | 
|
| 
       533 
532 
     | 
    
         
             
                for mm_inputs in mm_inputs_list:
         
     | 
| 
       534 
533 
     | 
    
         
             
                    item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
         
     | 
| 
       535 
534 
     | 
    
         | 
| 
       536 
     | 
    
         
            -
                 
     | 
| 
      
 535 
     | 
    
         
            +
                # deepstack_embeddings: per-modality
         
     | 
| 
      
 536 
     | 
    
         
            +
                modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
         
     | 
| 
      
 537 
     | 
    
         
            +
             
     | 
| 
       537 
538 
     | 
    
         
             
                # 2. Get multimodal embedding separately
         
     | 
| 
       538 
539 
     | 
    
         
             
                # Try get mm embedding if any
         
     | 
| 
       539 
540 
     | 
    
         
             
                for modality in Modality.all():
         
     | 
| 
         @@ -549,7 +550,8 @@ def embed_mm_inputs( 
     | 
|
| 
       549 
550 
     | 
    
         
             
                        # "image", "video", etc
         
     | 
| 
       550 
551 
     | 
    
         
             
                        modality_id = modality.name.lower()
         
     | 
| 
       551 
552 
     | 
    
         
             
                        embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
         
     | 
| 
       552 
     | 
    
         
            -
                    if len(items) != 0 
     | 
| 
      
 553 
     | 
    
         
            +
                    if len(items) != 0:
         
     | 
| 
      
 554 
     | 
    
         
            +
                        assert embedder is not None, f"no embedding method found for {modality}"
         
     | 
| 
       553 
555 
     | 
    
         
             
                        placeholder_tensor = torch.as_tensor(
         
     | 
| 
       554 
556 
     | 
    
         
             
                            [item.pad_value for item in items],
         
     | 
| 
       555 
557 
     | 
    
         
             
                            device=input_ids.device,
         
     | 
| 
         @@ -580,11 +582,12 @@ def embed_mm_inputs( 
     | 
|
| 
       580 
582 
     | 
    
         
             
                            items_offset_list=items_offsets,
         
     | 
| 
       581 
583 
     | 
    
         
             
                        )
         
     | 
| 
       582 
584 
     | 
    
         | 
| 
       583 
     | 
    
         
            -
                        if use_deepstack and embedding is not None:
         
     | 
| 
      
 585 
     | 
    
         
            +
                        if use_deepstack.get(modality, None) and embedding is not None:
         
     | 
| 
       584 
586 
     | 
    
         
             
                            embedding, deepstack_embedding = (
         
     | 
| 
       585 
587 
     | 
    
         
             
                                multimodal_model.separate_deepstack_embeds(embedding)
         
     | 
| 
       586 
588 
     | 
    
         
             
                            )
         
     | 
| 
       587 
589 
     | 
    
         
             
                            deepstack_embeddings += [deepstack_embedding]
         
     | 
| 
      
 590 
     | 
    
         
            +
                        modalities += [modality]
         
     | 
| 
       588 
591 
     | 
    
         
             
                        embeddings += [embedding]
         
     | 
| 
       589 
592 
     | 
    
         
             
                        masks += [mask]
         
     | 
| 
       590 
593 
     | 
    
         | 
| 
         @@ -597,17 +600,14 @@ def embed_mm_inputs( 
     | 
|
| 
       597 
600 
     | 
    
         
             
                input_ids.clamp_(min=0, max=vocab_size - 1)
         
     | 
| 
       598 
601 
     | 
    
         
             
                inputs_embeds = input_embedding(input_ids)
         
     | 
| 
       599 
602 
     | 
    
         | 
| 
       600 
     | 
    
         
            -
                # 4. scatter embeddings into input embedding
         
     | 
| 
       601 
     | 
    
         
            -
             
     | 
| 
       602 
603 
     | 
    
         
             
                # deepstack embedding
         
     | 
| 
       603 
604 
     | 
    
         
             
                if use_deepstack:
         
     | 
| 
       604 
     | 
    
         
            -
                    num_deepstack_embeddings = (
         
     | 
| 
       605 
     | 
    
         
            -
             
     | 
| 
       606 
     | 
    
         
            -
                    )
         
     | 
| 
      
 605 
     | 
    
         
            +
                    num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
         
     | 
| 
      
 606 
     | 
    
         
            +
             
     | 
| 
       607 
607 
     | 
    
         
             
                    deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
         
     | 
| 
       608 
608 
     | 
    
         
             
                        inputs_embeds.shape[-1] * num_deepstack_embeddings,
         
     | 
| 
       609 
609 
     | 
    
         
             
                    )
         
     | 
| 
       610 
     | 
    
         
            -
             
     | 
| 
      
 610 
     | 
    
         
            +
                    # a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
         
     | 
| 
       611 
611 
     | 
    
         
             
                    input_deepstack_embeds = torch.zeros(
         
     | 
| 
       612 
612 
     | 
    
         
             
                        deepstack_embedding_shape,
         
     | 
| 
       613 
613 
     | 
    
         
             
                        device=inputs_embeds.device,
         
     | 
| 
@@ -616,14 +616,16 @@ def embed_mm_inputs(
 
         other_info["input_deepstack_embeds"] = input_deepstack_embeds
 
-    for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
+    # 4. scatter embeddings into input embedding
+    for i, modality, embedding, mask in zip(
+        range(len(embeddings)), modalities, embeddings, masks
+    ):
         if embedding is None or mask is None:
             continue
         # in-place update
         indices = torch.where(mask.squeeze(dim=-1))[0]
         inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
-
-        if use_deepstack:
+        if use_deepstack.get(modality, None):
             input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
                 inputs_embeds.device, inputs_embeds.dtype
             )
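The rewritten scatter loop updates only the rows of `inputs_embeds` selected by each modality's boolean mask, and touches the deepstack buffer only for modalities with deepstack enabled. A self-contained sketch of the masked in-place update (tensor sizes are made up for illustration):

```python
import torch

inputs_embeds = torch.zeros(10, 8)           # [num_tokens, hidden]
mask = torch.zeros(10, 1, dtype=torch.bool)
mask[[2, 3, 4]] = True                       # three placeholder positions
embedding = torch.randn(3, 8)                # one row per placeholder token

indices = torch.where(mask.squeeze(dim=-1))[0]  # tensor([2, 3, 4])
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
assert torch.equal(inputs_embeds[2:5], embedding)
```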
@@ -640,7 +642,7 @@ def general_mm_embed_routine(
         Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
-    use_deepstack: bool = False,
+    use_deepstack: Dict[Modality, bool] = {},
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -652,7 +654,7 @@ def general_mm_embed_routine(
         language_model: Base language model to use
         data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
         placeholder_tokens: Token IDs for multimodal placeholders
-        use_deepstack: Whether to use deepstack embeddings
+        use_deepstack: Whether to use deepstack embeddings for each modality, default False
         **kwargs: Additional arguments passed to language model
 
     Returns:
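With the new signature, callers supply a per-modality mapping instead of a single flag, and `use_deepstack.get(modality, None)` falls back to a falsy value for unlisted modalities, which is what the "default False" note in the docstring refers to. A runnable illustration with a stand-in enum (not sglang's actual `Modality`):

```python
from enum import Enum, auto

class Modality(Enum):  # stand-in for sglang's Modality enum
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()

# Enable deepstack only for images; absent keys behave like False.
use_deepstack = {Modality.IMAGE: True}

for m in Modality:
    enabled = use_deepstack.get(m, None)
    print(m.name, "deepstack" if enabled else "plain")
```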
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +23,7 @@ import sys
 import threading
 from functools import partialmethod
 from multiprocessing import shared_memory
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Union
 
 import setproctitle
 import zmq
@@ -30,12 +32,12 @@ import zmq.asyncio
 from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
+    BaseBatchReq,
+    BaseReq,
     BatchEmbeddingOutput,
     BatchMultimodalOutput,
     BatchStrOutput,
     BatchTokenIDOutput,
-    MultiTokenizerRegisterReq,
-    MultiTokenizerWrapper,
 )
 from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -43,6 +45,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.detokenizer_manager import DetokenizerManager
+
 logger = logging.getLogger(__name__)
 
 
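`from __future__ import annotations` (added at the top of the file) together with the `TYPE_CHECKING` guard lets the mixins below annotate `self: DetokenizerManager` without importing the manager at runtime, which would otherwise risk a circular import. The pattern in isolation:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only by type checkers, never at runtime.
    from sglang.srt.managers.detokenizer_manager import DetokenizerManager

class MultiHttpWorkerDetokenizerMixin:
    def maybe_clear_socket_mapping(self: DetokenizerManager):
        # The annotation stays a string at runtime thanks to the
        # __future__ import, so DetokenizerManager need not be importable.
        if hasattr(self, "socket_mapping"):
            self.socket_mapping.clear_all_sockets()
```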
@@ -56,29 +61,24 @@ class SocketMapping:
             socket.close()
         self._mapping.clear()
 
-    def register_ipc_mapping(
-        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
-    ):
+    def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool):
         type_str = "tokenizer" if is_tokenizer else "detokenizer"
-        if worker_id in self._mapping:
-            logger.warning(
-                f"{type_str} already registered with worker {worker_id}, skipping..."
-            )
+        if ipc_name in self._mapping:
+            logger.warning(f"{type_str} already registered {ipc_name=}, skipping...")
             return
-        logger.info(
-            f"{type_str} not registered with worker {worker_id}, registering..."
-        )
-        socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
-        self._mapping[worker_id] = socket
-        self._mapping[worker_id].send_pyobj(recv_obj)
-
-    def send_output(self, worker_id: str, output: Any):
-        if worker_id not in self._mapping:
-            logger.error(
-                f"worker ID {worker_id} not registered. Check if the server Process is alive"
-            )
+        logger.info(f"Registering {type_str} {ipc_name=} in SocketMapping...")
+        socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
+        self._mapping[ipc_name] = socket
+
+    def send_output(self, ipc_name: str, output: Any):
+        if ipc_name is None:
+            # Some unhandled cases
+            logger.warning(f"IPC name is None, output type={type(output)}, skipping...")
             return
-        self._mapping[worker_id].send_pyobj(output)
+
+        if ipc_name not in self._mapping:
+            self._register_ipc_mapping(ipc_name, is_tokenizer=False)
+        self._mapping[ipc_name].send_pyobj(output)
 
 
 def _handle_output_by_index(output, i):
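The reworked `SocketMapping` keys ZMQ PUSH sockets by IPC name and registers them lazily on first send. A simplified standalone version using plain pyzmq (the real class goes through sglang's `get_zmq_socket` helper; this sketch substitutes direct pyzmq calls):

```python
import logging
from typing import Any, Dict

import zmq

logger = logging.getLogger(__name__)

class SocketMapping:
    def __init__(self):
        self._zmq_context = zmq.Context()
        self._mapping: Dict[str, zmq.Socket] = {}

    def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool):
        type_str = "tokenizer" if is_tokenizer else "detokenizer"
        if ipc_name in self._mapping:
            logger.warning(f"{type_str} already registered {ipc_name=}, skipping...")
            return
        socket = self._zmq_context.socket(zmq.PUSH)
        socket.connect(ipc_name)  # e.g. "ipc:///tmp/http_worker_0"
        self._mapping[ipc_name] = socket

    def send_output(self, ipc_name: str, output: Any):
        if ipc_name is None:
            logger.warning(f"IPC name is None, output type={type(output)}, skipping...")
            return
        if ipc_name not in self._mapping:  # lazy registration on first send
            self._register_ipc_mapping(ipc_name, is_tokenizer=False)
        self._mapping[ipc_name].send_pyobj(output)
```

Keying by IPC name removes the old requirement that workers register themselves up front: any output carrying a routable `ipc_name` can be delivered.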
@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
                 if output.output_token_ids_logprobs_idx
                 else None
             ),
+            output_token_entropy_val=(
+                [output.output_token_entropy_val[i]]
+                if output.output_token_entropy_val
+                else None
+            ),
             output_hidden_states=(
                 [output.output_hidden_states[i]]
                 if output.output_hidden_states
@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchEmbeddingOutput):
         new_output = BatchEmbeddingOutput(
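`_handle_output_by_index` applies one pattern throughout, as the new `output_token_entropy_val` and `token_steps` fields show: each optional batch-level list becomes either a single-element list `[field[i]]` or `None`. A generic, runnable sketch of the idea (the dataclass is a stand-in, not sglang's io_struct types):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class BatchOut:  # stand-in for e.g. BatchTokenIDOutput
    rids: List[str]
    output_token_entropy_val: Optional[List[float]] = None
    token_steps: Optional[List[int]] = None

def handle_output_by_index(output: BatchOut, i: int) -> BatchOut:
    # Slice one request's view out of a batched output.
    return BatchOut(
        rids=[output.rids[i]],
        output_token_entropy_val=(
            [output.output_token_entropy_val[i]]
            if output.output_token_entropy_val
            else None
        ),
        token_steps=([output.token_steps[i]] if output.token_steps else None),
    )

batch = BatchOut(rids=["a", "b"], output_token_entropy_val=[0.1, 0.7])
single = handle_output_by_index(batch, 1)
assert single.rids == ["b"] and single.output_token_entropy_val == [0.7]
assert single.token_steps is None
```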
@@ -246,6 +252,11 @@ def _handle_output_by_index(output, i):
             spec_verify_ct=(
                 [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
             ),
+            spec_accepted_tokens=(
+                [output.spec_accepted_tokens[i]]
+                if len(output.spec_accepted_tokens) > i
+                else None
+            ),
             input_token_logprobs_val=(
                 [output.input_token_logprobs_val[i]]
                 if output.input_token_logprobs_val
@@ -306,6 +317,11 @@ def _handle_output_by_index(output, i):
                 if output.output_token_ids_logprobs_idx
                 else None
             ),
+            output_token_entropy_val=(
+                [output.output_token_entropy_val[i]]
+                if output.output_token_entropy_val
+                else None
+            ),
             output_hidden_states=(
                 [output.output_hidden_states[i]]
                 if output.output_hidden_states
@@ -313,6 +329,7 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchMultimodalOutput):
         new_output = BatchMultimodalOutput(
@@ -345,20 +362,11 @@ def _handle_output_by_index(output, i):
 class MultiHttpWorkerDetokenizerMixin:
     """Mixin class for DetokenizerManager"""
 
-    def get_worker_ids_from_req_rids(self, rids):
-        if isinstance(rids, list):
-            worker_ids = [int(rid.split("_")[0]) for rid in rids]
-        elif isinstance(rids, str):
-            worker_ids = [int(rids.split("_")[0])]
-        else:
-            worker_ids = []
-        return worker_ids
-
-    def maybe_clear_socket_mapping(self):
+    def maybe_clear_socket_mapping(self: DetokenizerManager):
         if hasattr(self, "socket_mapping"):
             self.socket_mapping.clear_all_sockets()
 
-    def multi_http_worker_event_loop(self):
+    def multi_http_worker_event_loop(self: DetokenizerManager):
         """The event loop that handles requests, for multi multi-http-worker mode"""
         self.socket_mapping = SocketMapping()
         while True:
@@ -366,23 +374,15 @@ class MultiHttpWorkerDetokenizerMixin:
             output = self._request_dispatcher(recv_obj)
             if output is None:
                 continue
-
-            if isinstance(recv_obj.rids, list):
-                worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
-            else:
-                raise RuntimeError(
-                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
-                )
+
+            assert isinstance(
+                recv_obj, BaseBatchReq
+            ), "for multi-http-worker, recv_obj must be BaseBatchReq"
 
             # Send data using the corresponding socket
-            for i, worker_id in enumerate(worker_ids):
-                if isinstance(output, MultiTokenizerRegisterReq):
-                    self.socket_mapping.register_ipc_mapping(
-                        recv_obj, worker_id, is_tokenizer=False
-                    )
-                else:
-                    new_output = _handle_output_by_index(output, i)
-                    self.socket_mapping.send_output(worker_id, new_output)
+            for i, ipc_name in enumerate(recv_obj.http_worker_ipcs):
+                new_output = _handle_output_by_index(output, i)
+                self.socket_mapping.send_output(ipc_name, new_output)
 
 
 class MultiTokenizerRouter:
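The event loop now asserts that the received object is a `BaseBatchReq` and fans each per-request slice out to the IPC name recorded in `http_worker_ipcs`. A runnable miniature of that fan-out, with in-memory queues standing in for the per-worker PUSH sockets:

```python
from dataclasses import dataclass
from queue import Queue
from typing import Dict, List

@dataclass
class BatchOut:  # stand-in batched output
    rids: List[str]

def handle_output_by_index(output: BatchOut, i: int) -> BatchOut:
    return BatchOut(rids=[output.rids[i]])

# Queues stand in for SocketMapping's per-IPC-name ZMQ PUSH sockets.
workers: Dict[str, Queue] = {"ipc:///tmp/w0": Queue(), "ipc:///tmp/w1": Queue()}

batch = BatchOut(rids=["req-a", "req-b"])
http_worker_ipcs = ["ipc:///tmp/w0", "ipc:///tmp/w1"]  # one entry per request

# Fan-out: slice the batch per index and route it to the recorded IPC name.
for i, ipc_name in enumerate(http_worker_ipcs):
    workers[ipc_name].put(handle_output_by_index(batch, i))

assert workers["ipc:///tmp/w1"].get().rids == ["req-b"]
```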
@@ -432,26 +432,17 @@ class MultiTokenizerRouter:
             await self._distribute_result_to_workers(recv_obj)
 
     async def _distribute_result_to_workers(self, recv_obj):
-        """Distribute result to corresponding workers based on rid"""
-        if isinstance(recv_obj, MultiTokenizerWrapper):
-            worker_ids = [recv_obj.worker_id]
-            recv_obj = recv_obj.obj
+        # Distribute result to each worker
+        if isinstance(recv_obj, BaseReq):
+            ipc_names = [recv_obj.http_worker_ipc]
+        elif isinstance(recv_obj, BaseBatchReq):
+            ipc_names = recv_obj.http_worker_ipcs
         else:
-            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
-
-        if len(worker_ids) == 0:
-            logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
-            return
+            raise ValueError(f"Unknown recv_obj type: {type(recv_obj)}")
 
-        # Distribute result to each worker
-        for i, worker_id in enumerate(worker_ids):
-            if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                self.socket_mapping.register_ipc_mapping(
-                    recv_obj, worker_id, is_tokenizer=True
-                )
-            else:
-                new_recv_obj = _handle_output_by_index(recv_obj, i)
-                self.socket_mapping.send_output(worker_id, new_recv_obj)
+        for i, ipc_name in enumerate(ipc_names):
+            new_recv_obj = _handle_output_by_index(recv_obj, i)
+            self.socket_mapping.send_output(ipc_name, new_recv_obj)
 
 
 class TokenizerWorker(TokenizerManager):
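`_distribute_result_to_workers` now derives target IPC names purely from the request type instead of parsing worker ids out of rid strings. The dispatch shape as standalone code, with stand-in classes (sglang's real `BaseReq`/`BaseBatchReq` live in io_struct):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class BaseReq:  # stand-in: single request, one originating HTTP worker
    http_worker_ipc: Optional[str] = None

@dataclass
class BaseBatchReq:  # stand-in: batched request, one IPC name per rid
    http_worker_ipcs: Optional[List[str]] = None

def ipc_names_for(recv_obj):
    if isinstance(recv_obj, BaseReq):
        return [recv_obj.http_worker_ipc]
    elif isinstance(recv_obj, BaseBatchReq):
        return recv_obj.http_worker_ipcs
    raise ValueError(f"Unknown recv_obj type: {type(recv_obj)}")

assert ipc_names_for(BaseReq(http_worker_ipc="ipc:///tmp/w0")) == ["ipc:///tmp/w0"]
assert ipc_names_for(BaseBatchReq(http_worker_ipcs=["ipc:///tmp/w0", "ipc:///tmp/w1"])) == [
    "ipc:///tmp/w0",
    "ipc:///tmp/w1",
]
```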
@@ -483,21 +474,15 @@ class TokenizerWorker(TokenizerManager):
         self.register_multi_tokenizer_communicator = _Communicator(
             self.send_to_scheduler, 2
         )
-        self._result_dispatcher._mapping.append(
-            (
-                MultiTokenizerRegisterReq,
-                self.register_multi_tokenizer_communicator.handle_recv,
-            )
-        )
 
-    async def register_to_main_tokenizer_manager(self):
-        """Register this worker to the main tokenizer manager"""
-        # create a handle loop to receive messages from the main tokenizer manager
-        self.auto_create_handle_loop()
-        req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
-        req.ipc_name = self.tokenizer_ipc_name
-        _Communicator.enable_multi_tokenizer = True
-        await self.register_multi_tokenizer_communicator(req)
+    def _attach_multi_http_worker_info(self, req: Union[BaseReq, BaseBatchReq]):
+
+        if isinstance(req, BaseReq):
+            req.http_worker_ipc = self.tokenizer_ipc_name
+        elif isinstance(req, BaseBatchReq):
+            req.http_worker_ipcs = [self.tokenizer_ipc_name] * len(req.rids)
+        else:
+            raise ValueError(f"Unknown req type: {type(req)}")
 
 
 async def print_exception_wrapper(func):
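`_attach_multi_http_worker_info` is the write side of the same scheme: the worker stamps its own tokenizer IPC name onto every outbound request so that results can later be routed back to it. A hypothetical round trip, reusing the stand-in classes and `ipc_names_for` from the previous sketch (the endpoint string is invented):

```python
worker_ipc = "ipc:///tmp/tokenizer_worker_3"  # assumed endpoint name

req = BaseReq()
req.http_worker_ipc = worker_ipc           # what _attach_multi_http_worker_info does

batch = BaseBatchReq()
batch.http_worker_ipcs = [worker_ipc] * 4  # one entry per rid in the batch

# The router's dispatch recovers exactly these names.
assert ipc_names_for(req) == [worker_ipc]
assert ipc_names_for(batch) == [worker_ipc] * 4
```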