sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,6 @@
|
|
|
1
|
-
from dataclasses import astuple, dataclass
|
|
2
|
-
from functools import lru_cache
|
|
3
1
|
from typing import Optional, Union
|
|
4
2
|
|
|
5
3
|
import torch
|
|
6
|
-
import torch.nn.functional as F
|
|
7
4
|
|
|
8
5
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
9
6
|
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
|
|
@@ -14,14 +11,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
|
|
|
14
11
|
fused_sigmoid_gating_delta_rule_update,
|
|
15
12
|
)
|
|
16
13
|
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
|
|
14
|
+
PAD_SLOT_ID,
|
|
17
15
|
causal_conv1d_fn,
|
|
18
16
|
causal_conv1d_update,
|
|
19
17
|
)
|
|
18
|
+
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
|
|
19
|
+
from sglang.srt.layers.attention.mamba.mamba2_metadata import (
|
|
20
|
+
ForwardMetadata,
|
|
21
|
+
Mamba2Metadata,
|
|
22
|
+
)
|
|
20
23
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
21
|
-
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
|
|
24
|
+
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
|
|
22
25
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
23
26
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
24
27
|
from sglang.srt.models.qwen3_next import fused_gdn_gating
|
|
28
|
+
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
|
25
29
|
from sglang.srt.speculative.spec_info import SpecInput
|
|
26
30
|
from sglang.srt.utils import is_cuda, is_npu
|
|
27
31
|
|
|
@@ -47,18 +51,10 @@ elif is_npu():
|
|
|
47
51
|
causal_conv1d_update = causal_conv1d_update_npu
|
|
48
52
|
|
|
49
53
|
|
|
50
|
-
|
|
51
|
-
class ForwardMetadata:
|
|
52
|
-
query_start_loc: Optional[torch.Tensor]
|
|
53
|
-
mamba_cache_indices: torch.Tensor
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
class MambaAttnBackend(AttentionBackend):
|
|
57
|
-
"""Attention backend using Mamba kernel."""
|
|
58
|
-
|
|
54
|
+
class MambaAttnBackendBase(AttentionBackend):
|
|
59
55
|
def __init__(self, model_runner: ModelRunner):
|
|
60
56
|
super().__init__()
|
|
61
|
-
self.pad_slot_id =
|
|
57
|
+
self.pad_slot_id = PAD_SLOT_ID
|
|
62
58
|
self.device = model_runner.device
|
|
63
59
|
self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
|
|
64
60
|
self.forward_metadata: ForwardMetadata = None
|
|
@@ -67,7 +63,7 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
67
63
|
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
|
|
68
64
|
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
|
|
69
65
|
|
|
70
|
-
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
|
66
|
+
def _forward_metadata(self, forward_batch: ForwardBatch):
|
|
71
67
|
bs = forward_batch.batch_size
|
|
72
68
|
|
|
73
69
|
if forward_batch.forward_mode.is_decode_or_idle():
|
|
@@ -97,11 +93,43 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
97
93
|
mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
|
|
98
94
|
forward_batch.req_pool_indices
|
|
99
95
|
)
|
|
100
|
-
|
|
96
|
+
return ForwardMetadata(
|
|
101
97
|
query_start_loc=query_start_loc,
|
|
102
98
|
mamba_cache_indices=mamba_cache_indices,
|
|
103
99
|
)
|
|
104
100
|
|
|
101
|
+
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
|
102
|
+
self.forward_metadata = self._forward_metadata(forward_batch)
|
|
103
|
+
|
|
104
|
+
def init_forward_metadata_capture_cuda_graph(
|
|
105
|
+
self,
|
|
106
|
+
bs: int,
|
|
107
|
+
num_tokens: int,
|
|
108
|
+
req_pool_indices: torch.Tensor,
|
|
109
|
+
seq_lens: torch.Tensor,
|
|
110
|
+
encoder_lens: Optional[torch.Tensor],
|
|
111
|
+
forward_mode: ForwardMode,
|
|
112
|
+
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
|
113
|
+
):
|
|
114
|
+
self.forward_metadata = self._capture_metadata(
|
|
115
|
+
bs, req_pool_indices, forward_mode
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
def init_forward_metadata_replay_cuda_graph(
|
|
119
|
+
self,
|
|
120
|
+
bs: int,
|
|
121
|
+
req_pool_indices: torch.Tensor,
|
|
122
|
+
seq_lens: torch.Tensor,
|
|
123
|
+
seq_lens_sum: int,
|
|
124
|
+
encoder_lens: Optional[torch.Tensor],
|
|
125
|
+
forward_mode: ForwardMode,
|
|
126
|
+
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
|
127
|
+
seq_lens_cpu: Optional[torch.Tensor],
|
|
128
|
+
):
|
|
129
|
+
self.forward_metadata = self._replay_metadata(
|
|
130
|
+
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
|
|
131
|
+
)
|
|
132
|
+
|
|
105
133
|
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
|
106
134
|
assert (
|
|
107
135
|
max_num_tokens % max_bs == 0
|
|
@@ -127,15 +155,8 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
127
155
|
device=self.device,
|
|
128
156
|
)
|
|
129
157
|
|
|
130
|
-
def init_forward_metadata_capture_cuda_graph(
|
|
131
|
-
self,
|
|
132
|
-
bs: int,
|
|
133
|
-
num_tokens: int,
|
|
134
|
-
req_pool_indices: torch.Tensor,
|
|
135
|
-
seq_lens: torch.Tensor,
|
|
136
|
-
encoder_lens: Optional[torch.Tensor],
|
|
137
|
-
forward_mode: ForwardMode,
|
|
138
|
-
spec_info: Optional[SpecInput],
|
|
158
|
+
def _capture_metadata(
|
|
159
|
+
self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
|
|
139
160
|
):
|
|
140
161
|
if forward_mode.is_decode_or_idle():
|
|
141
162
|
self.query_start_loc_list[bs - 1].copy_(
|
|
@@ -149,18 +170,15 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
149
170
|
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
|
150
171
|
mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
|
|
151
172
|
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
|
|
152
|
-
|
|
173
|
+
return ForwardMetadata(
|
|
153
174
|
query_start_loc=self.query_start_loc_list[bs - 1],
|
|
154
175
|
mamba_cache_indices=self.state_indices_list[bs - 1],
|
|
155
176
|
)
|
|
156
177
|
|
|
157
|
-
def init_forward_metadata_replay_cuda_graph(
|
|
178
|
+
def _replay_metadata(
|
|
158
179
|
self,
|
|
159
180
|
bs: int,
|
|
160
181
|
req_pool_indices: torch.Tensor,
|
|
161
|
-
seq_lens: torch.Tensor,
|
|
162
|
-
seq_lens_sum: int,
|
|
163
|
-
encoder_lens: Optional[torch.Tensor],
|
|
164
182
|
forward_mode: ForwardMode,
|
|
165
183
|
spec_info: Optional[SpecInput],
|
|
166
184
|
seq_lens_cpu: Optional[torch.Tensor],
|
|
@@ -200,7 +218,7 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
200
218
|
else:
|
|
201
219
|
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
|
202
220
|
|
|
203
|
-
|
|
221
|
+
return ForwardMetadata(
|
|
204
222
|
query_start_loc=self.query_start_loc_list[bs - 1],
|
|
205
223
|
mamba_cache_indices=self.state_indices_list[bs - 1],
|
|
206
224
|
)
|
|
@@ -208,6 +226,10 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
208
226
|
def get_cuda_graph_seq_len_fill_value(self):
|
|
209
227
|
return 1 # Mamba attn does not use seq lens to index kv cache
|
|
210
228
|
|
|
229
|
+
|
|
230
|
+
class GDNAttnBackend(MambaAttnBackendBase):
|
|
231
|
+
"""Attention backend using Mamba kernel."""
|
|
232
|
+
|
|
211
233
|
def forward_decode(
|
|
212
234
|
self,
|
|
213
235
|
q: torch.Tensor,
|
|
@@ -233,9 +255,9 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
233
255
|
dt_bias = kwargs["dt_bias"]
|
|
234
256
|
layer_id = kwargs["layer_id"]
|
|
235
257
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
258
|
+
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
|
|
259
|
+
conv_states = layer_cache.conv
|
|
260
|
+
ssm_states = layer_cache.temporal
|
|
239
261
|
query_start_loc = self.forward_metadata.query_start_loc
|
|
240
262
|
cache_indices = self.forward_metadata.mamba_cache_indices
|
|
241
263
|
|
|
@@ -313,13 +335,13 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
313
335
|
query_start_loc = self.forward_metadata.query_start_loc
|
|
314
336
|
cache_indices = self.forward_metadata.mamba_cache_indices
|
|
315
337
|
|
|
338
|
+
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
|
|
339
|
+
conv_states = mamba_cache_params.conv
|
|
340
|
+
ssm_states = mamba_cache_params.temporal
|
|
316
341
|
if is_target_verify:
|
|
317
|
-
(
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
intermediate_state_cache,
|
|
321
|
-
intermediate_conv_window_cache,
|
|
322
|
-
) = self.req_to_token_pool.get_mamba_params(layer_id)
|
|
342
|
+
assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
|
|
343
|
+
intermediate_state_cache = mamba_cache_params.intermediate_ssm
|
|
344
|
+
intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
|
|
323
345
|
has_initial_states = torch.ones(
|
|
324
346
|
seq_len // forward_batch.spec_info.draft_token_num,
|
|
325
347
|
dtype=torch.bool,
|
|
@@ -327,9 +349,6 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
327
349
|
)
|
|
328
350
|
conv_states_to_use = conv_states.clone()
|
|
329
351
|
else:
|
|
330
|
-
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
|
|
331
|
-
layer_id
|
|
332
|
-
)
|
|
333
352
|
has_initial_states = forward_batch.extend_prefix_lens > 0
|
|
334
353
|
conv_states_to_use = conv_states
|
|
335
354
|
|
|
@@ -424,16 +443,100 @@ class MambaAttnBackend(AttentionBackend):
|
|
|
424
443
|
return core_attn_out
|
|
425
444
|
|
|
426
445
|
|
|
446
|
+
class Mamba2AttnBackend(MambaAttnBackendBase):
|
|
447
|
+
"""Attention backend wrapper for Mamba2Mixer kernels."""
|
|
448
|
+
|
|
449
|
+
def __init__(self, model_runner: ModelRunner):
|
|
450
|
+
super().__init__(model_runner)
|
|
451
|
+
config = model_runner.mamba2_config
|
|
452
|
+
assert config is not None
|
|
453
|
+
self.mamba_chunk_size = config.mamba_chunk_size
|
|
454
|
+
|
|
455
|
+
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
|
456
|
+
metadata = self._forward_metadata(forward_batch)
|
|
457
|
+
self.forward_metadata = Mamba2Metadata.prepare_mixed(
|
|
458
|
+
metadata.query_start_loc,
|
|
459
|
+
metadata.mamba_cache_indices,
|
|
460
|
+
self.mamba_chunk_size,
|
|
461
|
+
forward_batch,
|
|
462
|
+
)
|
|
463
|
+
|
|
464
|
+
def init_forward_metadata_capture_cuda_graph(
|
|
465
|
+
self,
|
|
466
|
+
bs: int,
|
|
467
|
+
num_tokens: int,
|
|
468
|
+
req_pool_indices: torch.Tensor,
|
|
469
|
+
seq_lens: torch.Tensor,
|
|
470
|
+
encoder_lens: Optional[torch.Tensor],
|
|
471
|
+
forward_mode: ForwardMode,
|
|
472
|
+
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
|
473
|
+
):
|
|
474
|
+
metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
|
|
475
|
+
self.forward_metadata = Mamba2Metadata.prepare_decode(
|
|
476
|
+
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
|
|
477
|
+
)
|
|
478
|
+
|
|
479
|
+
def init_forward_metadata_replay_cuda_graph(
|
|
480
|
+
self,
|
|
481
|
+
bs: int,
|
|
482
|
+
req_pool_indices: torch.Tensor,
|
|
483
|
+
seq_lens: torch.Tensor,
|
|
484
|
+
seq_lens_sum: int,
|
|
485
|
+
encoder_lens: Optional[torch.Tensor],
|
|
486
|
+
forward_mode: ForwardMode,
|
|
487
|
+
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
|
488
|
+
seq_lens_cpu: Optional[torch.Tensor],
|
|
489
|
+
):
|
|
490
|
+
metadata = self._replay_metadata(
|
|
491
|
+
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
|
|
492
|
+
)
|
|
493
|
+
self.forward_metadata = Mamba2Metadata.prepare_decode(
|
|
494
|
+
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
def forward(
|
|
498
|
+
self,
|
|
499
|
+
mixer: MambaMixer2,
|
|
500
|
+
hidden_states: torch.Tensor,
|
|
501
|
+
output: torch.Tensor,
|
|
502
|
+
layer_id: int,
|
|
503
|
+
mup_vector: Optional[torch.Tensor] = None,
|
|
504
|
+
use_triton_causal_conv: bool = False,
|
|
505
|
+
):
|
|
506
|
+
assert isinstance(self.forward_metadata, Mamba2Metadata)
|
|
507
|
+
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
|
|
508
|
+
return mixer.forward(
|
|
509
|
+
hidden_states=hidden_states,
|
|
510
|
+
output=output,
|
|
511
|
+
layer_cache=layer_cache,
|
|
512
|
+
metadata=self.forward_metadata,
|
|
513
|
+
mup_vector=mup_vector,
|
|
514
|
+
use_triton_causal_conv=use_triton_causal_conv,
|
|
515
|
+
)
|
|
516
|
+
|
|
517
|
+
def forward_decode(self, *args, **kwargs):
|
|
518
|
+
raise NotImplementedError(
|
|
519
|
+
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
|
|
520
|
+
)
|
|
521
|
+
|
|
522
|
+
def forward_extend(self, *args, **kwargs):
|
|
523
|
+
raise NotImplementedError(
|
|
524
|
+
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
|
|
525
|
+
)
|
|
526
|
+
|
|
527
|
+
|
|
427
528
|
class HybridLinearAttnBackend(AttentionBackend):
|
|
428
|
-
"""
|
|
529
|
+
"""Manages a full and linear attention backend"""
|
|
429
530
|
|
|
430
531
|
def __init__(
|
|
431
532
|
self,
|
|
432
533
|
full_attn_backend: AttentionBackend,
|
|
433
|
-
linear_attn_backend:
|
|
534
|
+
linear_attn_backend: MambaAttnBackendBase,
|
|
434
535
|
full_attn_layers: list[int],
|
|
435
536
|
):
|
|
436
537
|
self.full_attn_layers = full_attn_layers
|
|
538
|
+
self.full_attn_backend = full_attn_backend
|
|
539
|
+
self.linear_attn_backend = linear_attn_backend
|
|
437
540
|
self.attn_backend_list = [full_attn_backend, linear_attn_backend]
|
|
438
541
|
|
|
439
542
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
|
@@ -489,7 +592,7 @@ class HybridLinearAttnBackend(AttentionBackend):
|
|
|
489
592
|
)
|
|
490
593
|
|
|
491
594
|
def get_cuda_graph_seq_len_fill_value(self):
|
|
492
|
-
return self.
|
|
595
|
+
return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
|
|
493
596
|
|
|
494
597
|
def forward_decode(
|
|
495
598
|
self,
|
|
@@ -503,10 +606,10 @@ class HybridLinearAttnBackend(AttentionBackend):
|
|
|
503
606
|
):
|
|
504
607
|
layer_id = layer.layer_id if layer else kwargs["layer_id"]
|
|
505
608
|
if layer_id in self.full_attn_layers:
|
|
506
|
-
return self.
|
|
609
|
+
return self.full_attn_backend.forward_decode(
|
|
507
610
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
|
508
611
|
)
|
|
509
|
-
return self.
|
|
612
|
+
return self.linear_attn_backend.forward_decode(
|
|
510
613
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
|
511
614
|
)
|
|
512
615
|
|
|
@@ -522,10 +625,10 @@ class HybridLinearAttnBackend(AttentionBackend):
|
|
|
522
625
|
):
|
|
523
626
|
layer_id = layer.layer_id if layer else kwargs["layer_id"]
|
|
524
627
|
if layer_id in self.full_attn_layers:
|
|
525
|
-
return self.
|
|
628
|
+
return self.full_attn_backend.forward_extend(
|
|
526
629
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
|
527
630
|
)
|
|
528
|
-
return self.
|
|
631
|
+
return self.linear_attn_backend.forward_extend(
|
|
529
632
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
|
530
633
|
)
|
|
531
634
|
|
|
@@ -568,20 +671,20 @@ class HybridLinearAttnBackend(AttentionBackend):
|
|
|
568
671
|
def update_mamba_state_after_mtp_verify(self, accepted_length, model):
|
|
569
672
|
request_number = accepted_length.shape[0]
|
|
570
673
|
|
|
571
|
-
state_indices_tensor =
|
|
572
|
-
|
|
573
|
-
|
|
674
|
+
state_indices_tensor = (
|
|
675
|
+
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
|
|
676
|
+
:request_number
|
|
677
|
+
]
|
|
678
|
+
)
|
|
574
679
|
|
|
575
|
-
mamba_caches =
|
|
576
|
-
|
|
577
|
-
|
|
680
|
+
mamba_caches = (
|
|
681
|
+
self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
|
|
682
|
+
)
|
|
578
683
|
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
intermediate_conv_window_cache,
|
|
584
|
-
) = mamba_caches
|
|
684
|
+
conv_states = mamba_caches.conv
|
|
685
|
+
ssm_states = mamba_caches.temporal
|
|
686
|
+
intermediate_state_cache = mamba_caches.intermediate_ssm
|
|
687
|
+
intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
|
|
585
688
|
|
|
586
689
|
# SSM state updates (chunked to reduce peak memory)
|
|
587
690
|
valid_mask = accepted_length > 0
|
|
@@ -4,13 +4,12 @@
|
|
|
4
4
|
|
|
5
5
|
from typing import List, Optional, Union
|
|
6
6
|
|
|
7
|
-
import numpy as np
|
|
8
7
|
import torch
|
|
9
|
-
|
|
10
|
-
PAD_SLOT_ID = -1
|
|
11
8
|
import triton
|
|
12
9
|
import triton.language as tl
|
|
13
10
|
|
|
11
|
+
PAD_SLOT_ID = -1
|
|
12
|
+
|
|
14
13
|
|
|
15
14
|
@triton.jit()
|
|
16
15
|
def _causal_conv1d_fwd_kernel( # continuous batching
|
|
@@ -672,7 +671,9 @@ def _causal_conv1d_update_kernel(
|
|
|
672
671
|
+ (conv_state_batch_coord * stride_conv_state_seq)
|
|
673
672
|
+ conv_state_token_offset * stride_conv_state_tok
|
|
674
673
|
+ (idx_feats * stride_conv_state_dim)[None, :]
|
|
675
|
-
+ ((idx_tokens + 1) * stride_conv_state_tok)[
|
|
674
|
+
+ ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
|
|
675
|
+
:, None
|
|
676
|
+
]
|
|
676
677
|
) # [BLOCK_M, BLOCK_N]
|
|
677
678
|
mask = (
|
|
678
679
|
(conv_state_batch_coord < num_cache_lines)
|
|
@@ -897,7 +898,10 @@ def causal_conv1d_update(
|
|
|
897
898
|
stride_state_indices = (
|
|
898
899
|
conv_state_indices.stride(0) if conv_state_indices is not None else 0
|
|
899
900
|
)
|
|
900
|
-
|
|
901
|
+
if num_accepted_tokens is not None:
|
|
902
|
+
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
|
903
|
+
else:
|
|
904
|
+
state_len = width - 1
|
|
901
905
|
np2_statelen = triton.next_power_of_2(state_len)
|
|
902
906
|
|
|
903
907
|
def grid(META):
|