sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py +112 -194

@@ -28,8 +28,10 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
+    get_int_env_var,
     is_flashinfer_available,
     is_sm100_supported,
     next_power_of_2,
@@ -39,11 +41,13 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+
 if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
+        fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
     from flashinfer.decode import _get_range_buf, get_seq_lens
@@ -122,12 +126,33 @@ class FlashInferAttnBackend(AttentionBackend):
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

+        # When deterministic inference is enabled, tensor cores should be used for decode
+        # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
+        # More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+        self.prefill_split_tile_size = None
+        self.decode_split_tile_size = None
+        self.disable_cuda_graph_kv_split = False
+        if self.enable_deterministic:
+            self.decode_use_tensor_cores = True
+            self.prefill_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
+            )
+            self.decode_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
+            )
+            self.disable_cuda_graph_kv_split = True
+            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
+            global_workspace_size = global_config.flashinfer_workspace_size
             global_workspace_buffer = torch.empty(
-
+                global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
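The deterministic-inference setup above is driven by one server flag and two integer environment knobs. Below is a minimal sketch of how an env-var helper like get_int_env_var can behave and how the knobs could be overridden; the helper body and the override values are illustrative assumptions, only the variable names and the 4096/2048 defaults come from the hunk above.

import os

def get_int_env_var(name: str, default: int) -> int:
    # Sketch only; sglang imports its own helper from sglang.srt.utils.
    value = os.environ.get(name, "").strip()
    return int(value) if value else default

# Hypothetical override: use smaller fixed split tiles than the defaults.
os.environ["SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE"] = "2048"
os.environ["SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE"] = "1024"

print(get_int_env_var("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096))  # 2048
print(get_int_env_var("SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048))   # 1024

Fixing the split size keeps the FlashInfer planner's KV partitioning independent of batch composition, which is what the deterministic mode is after (see the linked flashinfer PR for details).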
@@ -218,6 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=forward_batch.spec_info,
+                fixed_split_size=self.decode_split_tile_size,
+                disable_split_kv=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         elif forward_batch.forward_mode.is_draft_extend():
@@ -257,7 +284,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged = False
                 extend_no_prefix = False
             else:
-                use_ragged =
+                use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

             self.indices_updater_prefill.update(
@@ -270,6 +297,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
+                fixed_split_size=self.prefill_split_tile_size,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -317,7 +345,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -344,6 +372,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrappers
             self.forward_metadata = DecodeMetadata(decode_wrappers)
@@ -422,7 +452,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -434,6 +464,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
         elif forward_mode.is_target_verify():
             self.indices_updater_prefill.update(
@@ -638,7 +670,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -651,7 +685,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -663,6 +699,8 @@ class FlashInferIndicesUpdaterDecode:
             None,
             spec_info,
             seq_lens_cpu,
+            fixed_split_size=fixed_split_size,
+            disable_split_kv=disable_split_kv,
         )

     def update_sliding_window(
@@ -673,7 +711,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         assert self.sliding_window_size is not None
         for wrapper_id in range(2):
@@ -721,7 +761,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -753,9 +795,11 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -799,19 +843,51 @@ class FlashInferIndicesUpdaterDecode:
             global_override_indptr_cpu[0] = 0
             global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)

-        wrapper
-
-
-
-
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
+        # Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
+        # by checking if it's a partial function with fast_decode_plan as the func
+        wrapper_uses_fast_decode_plan = (
+            hasattr(wrapper.begin_forward, "func")
+            and wrapper.begin_forward.func == fast_decode_plan
         )

+        if wrapper_uses_fast_decode_plan:
+            # When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+                global_override_indptr_cpu=global_override_indptr_cpu,
+            )
+        else:
+            # When using original begin_forward, don't pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+            )
+
         if locally_override:
             global_override_indptr_cpu = None

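The wrapper_uses_fast_decode_plan check above leans on a property of functools.partial: the wrapped callable is exposed as a .func attribute, which a plain bound method does not have. A self-contained sketch of that detection pattern (the Wrapper class and the patching step are illustrative, not sglang's actual code):

from functools import partial

def fast_decode_plan(wrapper, *args, **kwargs):
    # Stand-in for the planner that the backend imports from flashinfer.
    return "fast plan"

class Wrapper:
    def begin_forward(self, *args, **kwargs):
        return "default plan"

w = Wrapper()
# Patch a single instance: bind the replacement planner to this wrapper.
w.begin_forward = partial(fast_decode_plan, w)

uses_fast_plan = (
    hasattr(w.begin_forward, "func") and w.begin_forward.func is fast_decode_plan
)
print(uses_fast_plan)  # True for the patched instance, False for an unpatched one

Only the patched path receives the global_override_indptr_cpu keyword, since the stock begin_forward does not accept it.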
@@ -858,7 +934,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -873,7 +950,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -897,6 +975,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.qo_indptr[0],
             use_ragged,
             spec_info,
+            fixed_split_size=fixed_split_size,
         )

     def update_sliding_window(
@@ -909,7 +988,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -955,7 +1035,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -997,8 +1078,9 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1024,9 +1106,7 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info,
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
@@ -1069,6 +1149,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_data_type=self.data_type,
             custom_mask=custom_mask,
             non_blocking=True,
+            fixed_split_size=fixed_split_size,
         )


@@ -1084,7 +1165,7 @@ class FlashInferMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -1148,7 +1229,7 @@ class FlashInferMultiStepDraftBackend:
         )

         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()

         # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
         indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
@@ -1276,166 +1357,3 @@ def should_use_tensor_core(
         return gqa_group_size >= 4
     else:
         return False
-
-
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global_override_indptr_cpu = None
-
-
-def fast_decode_plan(
-    self,
-    indptr: torch.Tensor,
-    indices: torch.Tensor,
-    last_page_len: torch.Tensor,
-    num_qo_heads: int,
-    num_kv_heads: int,
-    head_dim: int,
-    page_size: int,
-    pos_encoding_mode: str = "NONE",
-    window_left: int = -1,
-    logits_soft_cap: Optional[float] = None,
-    q_data_type: Optional[Union[str, torch.dtype]] = None,
-    kv_data_type: Optional[Union[str, torch.dtype]] = None,
-    data_type: Optional[Union[str, torch.dtype]] = None,
-    sm_scale: Optional[float] = None,
-    rope_scale: Optional[float] = None,
-    rope_theta: Optional[float] = None,
-    non_blocking: bool = True,
-) -> None:
-    """
-    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
-    Modifications:
-    - Remove unnecessary device-to-device copy for the cuda graph buffers.
-    - Remove unnecessary host-to-device copy for the metadata buffers.
-    """
-    batch_size = len(last_page_len)
-    if logits_soft_cap is None:
-        logits_soft_cap = 0.0
-
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
-    if self.is_cuda_graph_enabled:
-        if batch_size != self._fixed_batch_size:
-            raise ValueError(
-                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
-                " mismatches the batch size set during initialization {}".format(
-                    batch_size, self._fixed_batch_size
-                )
-            )
-        if len(indices) > len(self._paged_kv_indices_buf):
-            raise ValueError(
-                "The size of indices should be less than or equal to the allocated buffer"
-            )
-    else:
-        self._paged_kv_indptr_buf = indptr
-        self._paged_kv_indices_buf = indices
-        self._paged_kv_last_page_len_buf = last_page_len
-        if self.use_tensor_cores:
-            self._qo_indptr_buf = qo_indptr_host.to(
-                self.device, non_blocking=non_blocking
-            )
-
-    # Create empty tensors for dtype info if needed
-    empty_q_data = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-        ),
-        device=self.device,
-    )
-
-    empty_kv_cache = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, kv_data_type)
-            if isinstance(kv_data_type, str)
-            else kv_data_type
-        ),
-        device=self.device,
-    )
-
-    indptr_host = (
-        global_override_indptr_cpu
-        if global_override_indptr_cpu is not None
-        else indptr.cpu()
-    )
-
-    with torch.cuda.device(self.device):
-
-        if self.use_tensor_cores:
-            # ALSO convert last_page_len to CPU
-            if page_size == 1:
-                # When page size is 1, last_page_len is always 1.
-                # Directly construct the host tensor rather than executing a device-to-host copy.
-                last_page_len_host = torch.ones(
-                    (batch_size,), dtype=torch.int32, device="cpu"
-                )
-            else:
-                last_page_len_host = last_page_len.cpu()
-
-            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
-
-            try:
-                # Make sure we pass exactly 15 arguments for tensor core version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    qo_indptr_host,
-                    indptr_host,
-                    kv_lens_arr_host,
-                    batch_size,  # total_num_rows
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    head_dim,
-                    head_dim,
-                    False,  # causal
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-        else:
-            try:
-                # Make sure we pass exactly 15 arguments for standard version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    indptr_host,
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    window_left,
-                    logits_soft_cap,
-                    head_dim,
-                    head_dim,
-                    empty_q_data,
-                    empty_kv_cache,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-
-    self._pos_encoding_mode = pos_encoding_mode
-    self._window_left = window_left
-    self._logits_soft_cap = logits_soft_cap
-    self._sm_scale = sm_scale
-    self._rope_scale = rope_scale
-    self._rope_theta = rope_theta
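With fast_decode_plan now imported from flashinfer at the top of this file, the local copy of that helper above is removed rather than maintained in sglang. The other recurring change in this file is the spec_info typing: the backend now annotates against the generic SpecInput from sglang.srt.speculative.spec_info and calls spec_info.is_draft_input() instead of asserting Eagle-specific classes. A rough sketch of what such an interface enables is below; the class and method bodies are assumptions for illustration, not the actual sglang definitions.

from abc import ABC, abstractmethod

class SpecInput(ABC):
    # Assumed minimal shape of the shared speculative-input interface.
    @abstractmethod
    def is_draft_input(self) -> bool: ...

class DraftInput(SpecInput):   # hypothetical stand-in for e.g. EagleDraftInput
    def is_draft_input(self) -> bool:
        return True

class VerifyInput(SpecInput):  # hypothetical stand-in for e.g. EagleVerifyInput
    def is_draft_input(self) -> bool:
        return False

def handle(spec_info: SpecInput) -> str:
    # A backend only relies on the interface, so non-Eagle speculative inputs
    # (for example the new ngram worker in this release) can reuse the same path.
    assert isinstance(spec_info, SpecInput)
    return "draft" if spec_info.is_draft_input() else "verify"

print(handle(DraftInput()), handle(VerifyInput()))  # draft verify

The same SpecInput migration appears in the MLA and FlashMLA backends below.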
sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15

@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
     is_sm100_supported,
@@ -40,7 +40,7 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput

 if is_flashinfer_available():
     from flashinfer import (
@@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
@@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
         init_metadata_replay: bool = False,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
         init_metadata_replay: bool = False,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         bs = len(req_pool_indices)
@@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
         use_ragged: bool,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput] = None,
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput] = None,
     ):
         bs = len(seq_lens)
         sm_scale = self.scaling
@@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info,
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             # TODO: Support topk > 1 with custom mask
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         if topk > 1:
             raise ValueError(
@@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )

         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
@@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
         )

         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
sglang/srt/layers/attention/flashmla_backend.py +7 -5

@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput


 # FlashMLA only supports pagesize=64
@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-
+                num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -257,7 +258,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):

@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-
+                num_q_heads,
                 1,
            )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)