sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/attention_registry.py (new file)
@@ -0,0 +1,206 @@
+import logging
+
+logger = logging.getLogger(__name__)
+
+ATTENTION_BACKENDS = {}
+
+
+def register_attention_backend(name):
+    def decorator(fn):
+        ATTENTION_BACKENDS[name] = fn
+        return fn
+
+    return decorator
+
+
+@register_attention_backend("flashinfer")
+def create_flashinfer_backend(runner):
+    import torch
+
+    if not runner.use_mla_backend:
+        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+
+        # Init streams
+        if runner.server_args.speculative_algorithm == "EAGLE":
+            if (
+                not hasattr(runner, "plan_stream_for_flashinfer")
+                or not runner.plan_stream_for_flashinfer
+            ):
+                runner.plan_stream_for_flashinfer = torch.cuda.Stream()
+        return FlashInferAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.flashinfer_mla_backend import (
+            FlashInferMLAAttnBackend,
+        )
+
+        return FlashInferMLAAttnBackend(runner)
+
+
+@register_attention_backend("trtllm_mla")
+def create_trtllm_mla_backend(runner):
+    if not runner.use_mla_backend:
+        raise ValueError("trtllm_mla backend can only be used with MLA models.")
+    from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+    return TRTLLMMLABackend(runner)
+
+
+@register_attention_backend("aiter")
+def create_aiter_backend(runner):
+    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+    return AiterAttnBackend(runner)
+
+
+@register_attention_backend("wave")
+def create_wave_backend(runner):
+    from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+    return WaveAttnBackend(runner)
+
+
+@register_attention_backend("ascend")
+def create_ascend_backend(runner):
+    from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+    return AscendAttnBackend(runner)
+
+
+@register_attention_backend("nsa")
+def create_nsa_backend(runner):
+    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+    return NativeSparseAttnBackend(runner)
+
+
+@register_attention_backend("triton")
+def create_triton_backend(runner):
+    assert not runner.model_config.is_encoder_decoder, (
+        "Cross attention is not supported in the triton attention backend. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    if runner.server_args.enable_double_sparsity:
+        from sglang.srt.layers.attention.double_sparsity_backend import (
+            DoubleSparseAttnBackend,
+        )
+
+        return DoubleSparseAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(runner)
+
+
+@register_attention_backend("torch_native")
+def create_torch_native_backend(runner):
+    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+
+    return TorchNativeAttnBackend(runner)
+
+
+@register_attention_backend("flex_attention")
+def create_flex_attention_backend(runner):
+    from sglang.srt.layers.attention.torch_flex_backend import TorchFlexAttnBackend
+
+    return TorchFlexAttnBackend(runner)
+
+
+@register_attention_backend("flashmla")
+def create_flashmla_backend(runner):
+    from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+    return FlashMLABackend(runner)
+
+
+@register_attention_backend("fa3")
+def create_flashattention_v3_backend(runner):
+    import torch
+
+    assert (
+        torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend
+    ) or torch.cuda.get_device_capability()[0] == 9, (
+        "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner)
+
+
+@register_attention_backend("fa4")
+def create_flashattention_v4_backend(runner):
+    assert (
+        runner.use_mla_backend
+    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner, fa_impl_ver=4)
+
+
+@register_attention_backend("cutlass_mla")
+def create_cutlass_mla_backend(runner):
+    from sglang.srt.layers.attention.cutlass_mla_backend import CutlassMLABackend
+
+    return CutlassMLABackend(runner)
+
+
+@register_attention_backend("trtllm_mha")
+def create_trtllm_mha_backend(runner):
+    if runner.use_mla_backend:
+        raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
+    from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+    return TRTLLMHAAttnBackend(runner)
+
+
+@register_attention_backend("intel_amx")
+def create_intel_amx_backend(runner):
+    from sglang.srt.layers.attention.intel_amx_backend import IntelAMXAttnBackend
+
+    return IntelAMXAttnBackend(runner)
+
+
+@register_attention_backend("dual_chunk_flash_attn")
+def create_dual_chunk_flash_attn_backend(runner):
+    from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+        DualChunkFlashAttentionBackend,
+    )
+
+    return DualChunkFlashAttentionBackend(runner)
+
+
+def attn_backend_wrapper(runner, full_attn_backend):
+    """
+    Wrapper for special models like hybrid GDN, so we don't
+    need to change the code of the original attention backend.
+    """
+    assert not (
+        runner.is_hybrid_gdn and runner.use_mla_backend
+    ), "hybrid_gdn can only be used with non-MLA models."
+
+    # wrap for hybrid GDN models
+    if runner.is_hybrid_gdn:
+        from sglang.srt.utils import is_blackwell, is_npu
+
+        if is_blackwell():
+            assert (
+                runner.server_args.attention_backend == "triton"
+                or runner.server_args.attention_backend == "trtllm_mha"
+            ), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
+        if is_npu():
+            assert (
+                runner.server_args.attention_backend == "ascend"
+            ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
+        logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+            HybridLinearAttnBackend,
+            MambaAttnBackend,
+        )
+
+        linear_attn_backend = MambaAttnBackend(runner)
+        full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
+        return HybridLinearAttnBackend(
+            full_attn_backend, linear_attn_backend, full_attn_layers
+        )
+
+    return full_attn_backend
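The new module replaces hard-coded backend selection with a name-keyed registry: each `@register_attention_backend("name")` factory is stored in `ATTENTION_BACKENDS`, and `attn_backend_wrapper` optionally wraps the result for hybrid-GDN models. A minimal sketch of how such a registry is typically consumed is shown below; the `runner` object and the selection helper are assumptions for illustration, and the actual lookup in sglang's `model_runner.py` may differ.

```python
# Sketch only: `runner` is assumed to expose the attributes the factories above
# use (server_args, model_config, use_mla_backend, is_hybrid_gdn, ...).
from sglang.srt.layers.attention.attention_registry import (
    ATTENTION_BACKENDS,
    attn_backend_wrapper,
)


def create_attention_backend(runner):
    # e.g. "fa3", "triton", "flashinfer", "trtllm_mla", ...
    name = runner.server_args.attention_backend
    if name not in ATTENTION_BACKENDS:
        raise ValueError(f"Unknown attention backend: {name}")
    full_attn_backend = ATTENTION_BACKENDS[name](runner)
    # Hybrid GDN models get wrapped; everything else is returned unchanged.
    return attn_backend_wrapper(runner, full_attn_backend)
```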
sglang/srt/layers/attention/base_attn_backend.py
@@ -6,9 +6,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.spec_info import SpecInput
 
 
 class AttentionBackend(ABC):
@@ -31,7 +32,7 @@ class AttentionBackend(ABC):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -44,7 +45,7 @@ class AttentionBackend(ABC):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Init the metadata for a forward pass for replaying a cuda graph."""
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
     def support_triton(self):
         """Check if the current backend supports triton."""
         return True
+
+    def get_indexer_metadata(
+        self,
+        layer_id: int,
+        forward_batch: ForwardBatch,
+    ) -> Optional[BaseIndexerMetadata]:
+        """Get the indexer metadata. None means don't support indexer."""
+        return None
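The new `get_indexer_metadata` hook defaults to `None`, so existing backends keep working unchanged; a backend that supports an NSA indexer overrides it. A hedged sketch of such an override follows; `MyIndexedBackend` and its metadata dictionary are hypothetical, and the real NSA backend in `nsa_backend.py` is considerably more involved.

```python
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend


class MyIndexedBackend(AttentionBackend):
    """Hypothetical backend that carries per-layer indexer metadata."""

    def __init__(self):
        super().__init__()
        # Assumed to be populated while building forward metadata; keyed by layer id.
        self._indexer_metadata = {}

    def get_indexer_metadata(self, layer_id, forward_batch):
        # Returning None signals "no indexer support" to callers,
        # matching the base-class default above.
        return self._indexer_metadata.get(layer_id)
```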
sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput
 
 _is_cuda = is_cuda()
 if _is_cuda:
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
 
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
@@ -1537,7 +1537,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
                 query_inter,
                 key_cache,
                 value_cache,
-                block_table
+                block_table,
                 decode_meta.seq_lens_inter,
                 softmax_scale,
                 causal=False,
sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py
@@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
         (1, 0),
     )
     b_k = tl.load(p_k, boundary_check=(0, 1))
-
-    b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
+    b_A += tl.dot(b_k, tl.trans(b_k))
 
     if USE_G:
         p_g = tl.make_block_ptr(
@@ -85,6 +84,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
         b_g_diff = b_g[:, None] - b_g[None, :]
         b_A = b_A * safe_exp(b_g_diff)
 
+    b_A *= b_beta[:, None]
     b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
     p_A = tl.make_block_ptr(
         A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
sglang/srt/layers/attention/fla/fused_recurrent.py
@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)
 
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)
@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)
 
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)
sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py
@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
 
         # Apply L2 normalization if enabled
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
 
         b_q = b_q * scale
 
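All three FLA kernels now normalize queries and keys with a small epsilon added inside the square root, which keeps the division finite when a block is all zeros. A standalone PyTorch sketch of the same pattern is given below for reference; the tensor shapes are illustrative and not taken from the kernels.

```python
import torch


def l2norm_with_eps(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """L2-normalize along the last dim, matching the kernels' epsilon placement.

    The epsilon sits inside the sqrt, so an all-zero row divides by
    sqrt(eps) instead of by zero.
    """
    return x / torch.sqrt(torch.sum(x * x, dim=-1, keepdim=True) + eps)


q = torch.zeros(2, 64)                    # an all-zero block would otherwise yield NaNs
print(l2norm_with_eps(q).isnan().any())   # tensor(False)
```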
sglang/srt/layers/attention/flashattention_backend.py
@@ -11,9 +11,8 @@ import triton.language as tl
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.
+from sglang.srt.speculative.spec_info import SpecInput
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -305,6 +304,7 @@ class FlashAttentionBackend(AttentionBackend):
         speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
+        fa_impl_ver=3,
     ):
         super().__init__()
 
@@ -338,6 +338,8 @@
         )
         self.speculative_step_id = speculative_step_id
 
+        self.fa_impl_ver = fa_impl_ver
+
         # Local attention settings
         self.attention_chunk_size = (
             model_runner.attention_chunk_size
@@ -352,6 +354,13 @@
             self.sliding_window_size is not None and self.sliding_window_size > -1
         )
 
+        # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
+        # We set nums splits to 1 if deterministic inference is enabled.
+        # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
+        self.num_splits = (
+            1 if model_runner.server_args.enable_deterministic_inference else 0
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
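The new `num_splits` field is then threaded into the attention calls later in this file (`num_splits=self.num_splits`): `0` lets the FA3 kernel pick the KV split count heuristically, while the `enable_deterministic_inference` server arg pins it to `1` so the reduction order does not depend on batch size. A hedged sketch of that decision in isolation; the helper name and the commented call site are illustrative, not part of the package.

```python
def pick_num_splits(enable_deterministic_inference: bool) -> int:
    """Mirror the backend's split policy.

    0 -> let the kernel choose the number of KV splits heuristically
         (fast, but the batch-dependent split count changes reduction order).
    1 -> force a single split so the reduction order, and therefore the
         output bits, stay fixed across batch sizes.
    """
    return 1 if enable_deterministic_inference else 0


# Illustrative call site, following the fa3 keyword used in the diff:
# out = flash_attn_with_kvcache(..., num_splits=pick_num_splits(deterministic))
```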
@@ -682,8 +691,13 @@ class FlashAttentionBackend(AttentionBackend):
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
-        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case
-
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
+        # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+        if (
+            self.kv_cache_dtype_str != "auto"
+            and layer.head_dim <= 256
+            and self.fa_impl_ver != 4
+        ):
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
@@ -712,6 +726,8 @@ class FlashAttentionBackend(AttentionBackend):
 
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks
 
@@ -738,6 +754,7 @@ class FlashAttentionBackend(AttentionBackend):
 
         # Use Flash Attention for prefill
         if not self.use_mla:
+            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
@@ -770,6 +787,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )
 
@@ -791,6 +809,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    num_splits=self.num_splits,
                     **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
@@ -830,6 +849,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=False,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             else:
                 # MHA for extend part of sequence without attending prefix kv cache
@@ -844,6 +864,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=True,
                     return_softmax_lse=forward_batch.mha_return_lse,
+                    **kwargs,
                 )
             if forward_batch.mha_return_lse:
                 output, lse, *rest = output
@@ -851,6 +872,7 @@ class FlashAttentionBackend(AttentionBackend):
                 return output, lse
             return output
         else:
+            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                 layer.layer_id
@@ -892,6 +914,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -913,6 +936,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        num_splits=self.num_splits,
                     )
                 )
                 o, _ = merge_state_v2_wrapper(
@@ -939,6 +963,7 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fa_impl_ver in [3], "Only FA3 support decoding"
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -985,6 +1010,8 @@ class FlashAttentionBackend(AttentionBackend):
 
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks
 
@@ -1030,6 +1057,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         elif use_local_attn:
@@ -1049,6 +1077,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         else:
@@ -1077,6 +1106,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=use_cascade_attn,
+                    num_splits=self.num_splits,
                     **kwargs,
                 )
                 if use_cascade_attn:
@@ -1098,6 +1128,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        num_splits=self.num_splits,
                         **kwargs,
                     )
                 )
@@ -1153,6 +1184,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1173,6 +1205,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        num_splits=self.num_splits,
                     )
                     o, _ = merge_state_v2(
                         o,
@@ -1453,7 +1486,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
@@ -1688,7 +1721,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: Optional[torch.Tensor] = None,
     ):
@@ -2306,7 +2339,7 @@ class FlashAttentionMultiStepBackend:
         forward_batch: ForwardBatch,
     ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()
 
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -2323,7 +2356,7 @@ class FlashAttentionMultiStepBackend:
         self, forward_batch: ForwardBatch, bs: int
     ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()
 
         for i in range(self.speculative_num_steps - 1):
             # TODO: incrementally update the metadata for the later steps,