sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -17,7 +17,11 @@ from typing import List, Optional, Tuple
|
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
|
|
20
|
-
from .base_grammar_backend import
|
|
20
|
+
from .base_grammar_backend import (
|
|
21
|
+
INVALID_GRAMMAR_OBJ,
|
|
22
|
+
BaseGrammarBackend,
|
|
23
|
+
BaseGrammarObject,
|
|
24
|
+
)
|
|
21
25
|
|
|
22
26
|
|
|
23
27
|
class ReasonerGrammarObject(BaseGrammarObject):
|
|
@@ -81,10 +85,9 @@ class ReasonerGrammarBackend(BaseGrammarBackend):
|
|
|
81
85
|
self.grammar_backend = grammar_backend
|
|
82
86
|
self.think_end_id = think_end_id
|
|
83
87
|
|
|
84
|
-
def _init_value_dispatch(
|
|
85
|
-
self, key: Tuple[str, str]
|
|
86
|
-
) -> Optional[ReasonerGrammarObject]:
|
|
88
|
+
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
|
87
89
|
ret = self.grammar_backend._init_value_dispatch(key)
|
|
88
|
-
|
|
89
|
-
|
|
90
|
+
# avoid wrapping invalid grammar, so that the scheduler can detect it
|
|
91
|
+
if ret is None or ret is INVALID_GRAMMAR_OBJ:
|
|
92
|
+
return ret
|
|
90
93
|
return ReasonerGrammarObject(ret, self.think_end_id)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from typing import Dict
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def is_legacy_structural_tag(obj: Dict) -> bool:
|
|
5
|
+
# test whether an object is a legacy structural tag
|
|
6
|
+
# see `StructuralTagResponseFormat` at `sglang.srt.entrypoints.openai.protocol`
|
|
7
|
+
if obj.get("structures", None) is not None:
|
|
8
|
+
assert obj.get("triggers", None) is not None
|
|
9
|
+
return True
|
|
10
|
+
else:
|
|
11
|
+
assert obj.get("format", None) is not None
|
|
12
|
+
return False
|
|
@@ -34,6 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
|
|
34
34
|
BaseGrammarObject,
|
|
35
35
|
GrammarStats,
|
|
36
36
|
)
|
|
37
|
+
from sglang.srt.constrained.utils import is_legacy_structural_tag
|
|
37
38
|
from sglang.srt.utils import is_hip
|
|
38
39
|
|
|
39
40
|
_is_hip = is_hip()
|
|
@@ -167,6 +168,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
|
167
168
|
tokenizer,
|
|
168
169
|
vocab_size: int,
|
|
169
170
|
model_eos_token_ids: Optional[List[int]] = None,
|
|
171
|
+
any_whitespace: bool = True,
|
|
170
172
|
):
|
|
171
173
|
super().__init__()
|
|
172
174
|
|
|
@@ -188,6 +190,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
|
188
190
|
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
|
189
191
|
self.vocab_size = vocab_size
|
|
190
192
|
self.override_stop_tokens = override_stop_tokens
|
|
193
|
+
self.any_whitespace = any_whitespace
|
|
191
194
|
|
|
192
195
|
def _from_context(
|
|
193
196
|
self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
|
|
@@ -212,12 +215,14 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
|
212
215
|
# Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
|
|
213
216
|
ctx = self.grammar_compiler.compile_builtin_json_grammar()
|
|
214
217
|
else:
|
|
215
|
-
ctx = self.grammar_compiler.compile_json_schema(
|
|
218
|
+
ctx = self.grammar_compiler.compile_json_schema(
|
|
219
|
+
schema=key_string, any_whitespace=self.any_whitespace
|
|
220
|
+
)
|
|
216
221
|
|
|
217
222
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
|
218
223
|
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
|
|
219
224
|
return INVALID_GRAMMAR_OBJ
|
|
220
|
-
return self._from_context(ctx, key_string, GrammarStats())
|
|
225
|
+
return self._from_context(ctx, key_string, GrammarStats(dispatch_type="json"))
|
|
221
226
|
|
|
222
227
|
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
|
223
228
|
try:
|
|
@@ -225,7 +230,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
|
225
230
|
except RuntimeError as e:
|
|
226
231
|
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
|
|
227
232
|
return INVALID_GRAMMAR_OBJ
|
|
228
|
-
return self._from_context(ctx, key_string, GrammarStats())
|
|
233
|
+
return self._from_context(ctx, key_string, GrammarStats(dispatch_type="ebnf"))
|
|
229
234
|
|
|
230
235
|
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
|
231
236
|
try:
|
|
@@ -233,26 +238,32 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
|
233
238
|
except RuntimeError as e:
|
|
234
239
|
logging.error(f"Hit invalid regex: {key_string=}, {e=}")
|
|
235
240
|
return INVALID_GRAMMAR_OBJ
|
|
236
|
-
return self._from_context(ctx, key_string, GrammarStats())
|
|
241
|
+
return self._from_context(ctx, key_string, GrammarStats(dispatch_type="regex"))
|
|
237
242
|
|
|
238
243
|
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
|
239
244
|
try:
|
|
245
|
+
# TODO(dark): it's REALLY stupid to construct object from string and decode it again
|
|
240
246
|
structural_tag = json.loads(key_string)
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
247
|
+
if is_legacy_structural_tag(structural_tag):
|
|
248
|
+
tags = [
|
|
249
|
+
StructuralTagItem(
|
|
250
|
+
begin=structure["begin"],
|
|
251
|
+
schema=json.dumps(structure["schema"]),
|
|
252
|
+
end=structure["end"],
|
|
253
|
+
)
|
|
254
|
+
for structure in structural_tag["structures"]
|
|
255
|
+
]
|
|
256
|
+
ctx = self.grammar_compiler.compile_structural_tag(
|
|
257
|
+
tags, structural_tag["triggers"]
|
|
246
258
|
)
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
ctx = self.grammar_compiler.compile_structural_tag(
|
|
250
|
-
tags, structural_tag["triggers"]
|
|
251
|
-
)
|
|
259
|
+
else:
|
|
260
|
+
ctx = self.grammar_compiler.compile_structural_tag(key_string)
|
|
252
261
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
|
253
262
|
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
|
254
263
|
return INVALID_GRAMMAR_OBJ
|
|
255
|
-
return self._from_context(
|
|
264
|
+
return self._from_context(
|
|
265
|
+
ctx, key_string, GrammarStats(dispatch_type="structural_tag")
|
|
266
|
+
)
|
|
256
267
|
|
|
257
268
|
def reset(self):
|
|
258
269
|
self.grammar_compiler.clear_cache()
|
sglang/srt/debug_utils/dumper.py
CHANGED
|
@@ -36,6 +36,15 @@ class _Dumper:
|
|
|
36
36
|
self._forward_pass_id = 0
|
|
37
37
|
|
|
38
38
|
def on_forward_pass_start(self):
|
|
39
|
+
"""This should be called on all ranks."""
|
|
40
|
+
|
|
41
|
+
if not self._enable:
|
|
42
|
+
return
|
|
43
|
+
|
|
44
|
+
# Users may want to `dump` only on some ranks, thus determine name here
|
|
45
|
+
if self._partial_name is None:
|
|
46
|
+
self._partial_name = _get_partial_name()
|
|
47
|
+
|
|
39
48
|
self._forward_pass_id += 1
|
|
40
49
|
print(
|
|
41
50
|
f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
|
|
@@ -48,11 +57,9 @@ class _Dumper:
|
|
|
48
57
|
assert (
|
|
49
58
|
self._forward_pass_id >= 1
|
|
50
59
|
), "Do you forget to call `dumper.on_forward_pass_start()`?"
|
|
60
|
+
assert self._partial_name is not None
|
|
51
61
|
self._dump_index += 1
|
|
52
62
|
|
|
53
|
-
if self._partial_name is None:
|
|
54
|
-
self._partial_name = _get_partial_name()
|
|
55
|
-
|
|
56
63
|
rank = _get_rank()
|
|
57
64
|
full_kwargs = dict(
|
|
58
65
|
forward_pass_id=self._forward_pass_id,
|
|
@@ -13,7 +13,7 @@ from sglang.srt.disaggregation.mooncake.conn import (
|
|
|
13
13
|
MooncakeKVReceiver,
|
|
14
14
|
MooncakeKVSender,
|
|
15
15
|
)
|
|
16
|
-
from sglang.srt.utils import
|
|
16
|
+
from sglang.srt.utils import get_local_ip_auto
|
|
17
17
|
|
|
18
18
|
logger = logging.getLogger(__name__)
|
|
19
19
|
|
|
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
|
|
21
21
|
class AscendKVManager(MooncakeKVManager):
|
|
22
22
|
def init_engine(self):
|
|
23
23
|
# TransferEngine initialized on ascend.
|
|
24
|
-
local_ip =
|
|
24
|
+
local_ip = get_local_ip_auto()
|
|
25
25
|
self.engine = AscendTransferEngine(
|
|
26
26
|
hostname=local_ip,
|
|
27
27
|
npu_id=self.kv_args.gpu_id,
|
|
@@ -1,10 +1,20 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
-
from typing import List
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
import torch
|
|
4
6
|
|
|
5
7
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
|
6
8
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
|
7
9
|
|
|
10
|
+
try:
|
|
11
|
+
from mf_adapter import TransferEngine
|
|
12
|
+
|
|
13
|
+
import_error = None
|
|
14
|
+
except ImportError as e:
|
|
15
|
+
import_error = e
|
|
16
|
+
pass
|
|
17
|
+
|
|
8
18
|
logger = logging.getLogger(__name__)
|
|
9
19
|
|
|
10
20
|
|
|
@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
|
|
|
13
23
|
def __init__(
|
|
14
24
|
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
|
|
15
25
|
):
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
except ImportError as e:
|
|
19
|
-
raise ImportError(
|
|
26
|
+
if import_error is not None:
|
|
27
|
+
logger.warning(
|
|
20
28
|
"Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
|
|
21
|
-
)
|
|
29
|
+
)
|
|
30
|
+
raise import_error
|
|
22
31
|
|
|
23
32
|
self.engine = TransferEngine()
|
|
24
33
|
self.hostname = hostname
|
|
@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
|
|
|
37
46
|
self.initialize()
|
|
38
47
|
|
|
39
48
|
def initialize(self) -> None:
|
|
49
|
+
from sglang.srt.layers.dp_attention import (
|
|
50
|
+
get_tensor_model_parallel_world_size,
|
|
51
|
+
get_tp_group,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
transfer_protocol = self._get_transfer_protocol()
|
|
55
|
+
if transfer_protocol is None or transfer_protocol == "sdma":
|
|
56
|
+
trans_op_type = TransferEngine.TransDataOpType.SDMA
|
|
57
|
+
else:
|
|
58
|
+
trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
|
|
59
|
+
"""with device RDMA for PD transfer"""
|
|
60
|
+
tmp_tensor = torch.zeros(1, device="npu")
|
|
61
|
+
output_tensor_list = [
|
|
62
|
+
torch.empty_like(tmp_tensor)
|
|
63
|
+
for _ in range(get_tensor_model_parallel_world_size())
|
|
64
|
+
]
|
|
65
|
+
# Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
|
|
66
|
+
torch.distributed.all_gather(
|
|
67
|
+
output_tensor_list, tmp_tensor, group=get_tp_group().device_group
|
|
68
|
+
)
|
|
40
69
|
"""Initialize the ascend transfer instance."""
|
|
41
70
|
ret_value = self.engine.initialize(
|
|
42
|
-
self.store_url,
|
|
43
|
-
self.session_id,
|
|
44
|
-
self.role,
|
|
45
|
-
self.npu_id,
|
|
71
|
+
self.store_url, self.session_id, self.role, self.npu_id, trans_op_type
|
|
46
72
|
)
|
|
47
73
|
if ret_value != 0:
|
|
48
74
|
logger.error("Ascend Transfer Engine initialization failed.")
|
|
@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
|
|
|
56
82
|
ret_value = -1
|
|
57
83
|
if ret_value != 0:
|
|
58
84
|
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
|
|
85
|
+
|
|
86
|
+
@staticmethod
|
|
87
|
+
def _get_transfer_protocol():
|
|
88
|
+
protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
|
|
89
|
+
allowed_protocols = {"device_rdma", "sdma"}
|
|
90
|
+
if protocol and protocol.lower() in allowed_protocols:
|
|
91
|
+
return protocol.lower()
|
|
92
|
+
else:
|
|
93
|
+
logger.warning(
|
|
94
|
+
"Invalid or no transfer protocol specified, using default protocol."
|
|
95
|
+
)
|
|
96
|
+
return None
|
|
@@ -20,6 +20,10 @@ class KVArgs:
|
|
|
20
20
|
aux_data_ptrs: List[int]
|
|
21
21
|
aux_data_lens: List[int]
|
|
22
22
|
aux_item_lens: List[int]
|
|
23
|
+
state_data_ptrs: List[int]
|
|
24
|
+
state_data_lens: List[int]
|
|
25
|
+
state_item_lens: List[int]
|
|
26
|
+
state_type: str # "none", "mamba", "swa"
|
|
23
27
|
ib_device: str
|
|
24
28
|
ib_traffic_class: str
|
|
25
29
|
gpu_id: int
|
|
@@ -76,9 +80,13 @@ class BaseKVSender(ABC):
|
|
|
76
80
|
...
|
|
77
81
|
|
|
78
82
|
@abstractmethod
|
|
79
|
-
def send(
|
|
83
|
+
def send(
|
|
84
|
+
self,
|
|
85
|
+
kv_indices: npt.NDArray[np.int32],
|
|
86
|
+
state_indices: Optional[List[int]] = None,
|
|
87
|
+
):
|
|
80
88
|
"""
|
|
81
|
-
Send the kv cache at the given kv indices to the decoder server
|
|
89
|
+
Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server
|
|
82
90
|
"""
|
|
83
91
|
...
|
|
84
92
|
|
|
@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC):
|
|
|
108
116
|
): ...
|
|
109
117
|
|
|
110
118
|
@abstractmethod
|
|
111
|
-
def init(
|
|
119
|
+
def init(
|
|
120
|
+
self,
|
|
121
|
+
kv_indices: npt.NDArray[np.int32],
|
|
122
|
+
aux_index: Optional[int] = None,
|
|
123
|
+
state_indices: Optional[List[int]] = None,
|
|
124
|
+
):
|
|
112
125
|
"""
|
|
113
|
-
Notify the prefill server about the kv indices
|
|
126
|
+
Notify the prefill server about the kv indices, aux index, and state_indices.
|
|
114
127
|
"""
|
|
115
128
|
...
|
|
116
129
|
|