sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/fla/utils.py (new file)
@@ -0,0 +1,331 @@
+# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
+# -*- coding: utf-8 -*-
+
+import contextlib
+import functools
+import logging
+import os
+import sys
+from enum import Enum
+from functools import lru_cache
+from typing import Any, Callable, Dict, Literal, Optional, Tuple
+
+import torch
+import triton
+from packaging import version
+
+logger = logging.getLogger(__name__)
+
+COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
+FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
+
+
+@lru_cache(maxsize=1)
+def check_environments():
+    """
+    Checks the current operating system, Triton version, and Python version,
+    issuing warnings if they don't meet recommendations.
+    This function's body only runs once due to lru_cache.
+    """
+    # Check Operating System
+    if sys.platform == "win32":
+        logger.warning(
+            "Detected Windows operating system. Triton does not have an official Windows release, "
+            "thus FLA will not be adapted for Windows, and any potential errors will not be fixed. "
+            "Please consider using a Linux environment for compatibility."
+        )
+
+    triton_version = version.parse(triton.__version__)
+    required_triton_version = version.parse("3.2.0")
+
+    if triton_version < required_triton_version:
+        logger.warning(
+            f"Current Triton version {triton_version} is below the recommended 3.2.0 version. "
+            "Errors may occur and these issues will not be fixed. "
+            "Please consider upgrading Triton."
+        )
+
+    # Check Python version
+    py_version = version.parse(f"{sys.version_info.major}.{sys.version_info.minor}")
+    required_py_version = version.parse("3.11")
+
+    if py_version < required_py_version:
+        logger.warning(
+            f"Current Python version {py_version} is below the recommended 3.11 version. "
+            "It is recommended to upgrade to Python 3.11 or higher for the best experience."
+        )
+
+    return None
+
+
+check_environments()
+
+
+def get_abs_err(x, y):
+    return (x.detach() - y.detach()).flatten().abs().max().item()
+
+
+def get_err_ratio(x, y):
+    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
+    base = (x.detach()).flatten().square().mean().sqrt().item()
+    return err / (base + 1e-8)
+
+
+def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
+    abs_atol = get_abs_err(ref, tri)
+    msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
+    logger.info(msg)
+    error_rate = get_err_ratio(ref, tri)
+    if abs_atol <= err_atol:
+        return
+    if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
+        if error_rate > ratio:
+            import warnings
+
+            warnings.warn(msg)
+    else:
+        assert error_rate < ratio, msg
+
+
+SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
+
+
+def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+    """
+    A decorator that caches the most recent results of a function with tensor inputs.
+    This decorator will store the output of the decorated function for the most recent set of input tensors.
+    The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
+    Args:
+        fn (Callable[..., torch.Tensor]):
+            The function to be decorated. It should take tensor inputs and return tensor outputs.
+    Returns:
+        Callable[..., torch.Tensor]:
+            A wrapped version of the input function with single-entry caching.
+    """
+
+    cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []
+    cache_size = 4
+
+    @functools.wraps(fn)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        nonlocal cache_entries, cache_size
+        for i, entry in enumerate(cache_entries):
+            last_args, last_kwargs, last_result = entry
+            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
+                if all(a is b for a, b in zip(args, last_args)) and all(
+                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
+                ):
+                    cache_entries = (
+                        cache_entries[:i]
+                        + cache_entries[i + 1 :]
+                        + [(args, kwargs, last_result)]
+                    )
+                    return last_result
+
+        result = fn(*args, **kwargs)
+
+        if len(cache_entries) >= cache_size:
+            cache_entries = cache_entries[1:]
+        cache_entries.append((args, kwargs, result))
+        return result
+
+    return wrapper
+
+
+def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+    """
+    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        contiguous_args = (
+            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
+        )
+        contiguous_kwargs = {
+            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+            for k, v in kwargs.items()
+        }
+
+        tensor = None
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                tensor = arg
+                break
+        if tensor is None:
+            for value in kwargs.values():
+                if isinstance(value, torch.Tensor):
+                    tensor = value
+                    break
+
+        if tensor is not None:
+            ctx = custom_device_ctx(tensor.device.index)
+        else:
+            ctx = contextlib.nullcontext()
+
+        with ctx:
+            return fn(*contiguous_args, **contiguous_kwargs)
+
+    return wrapper
+
+
+contiguous = input_guard
+
+
+def require_version(version, hint):
+    """
+    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
+    """
+
+    def decorator(fn):
+        @functools.wraps(fn)
+        def wrapper(ctx, *args, **kwargs):
+            from transformers.utils.versions import require_version
+
+            require_version(version, hint)
+            return fn(
+                ctx,
+                *(
+                    i if not isinstance(i, torch.Tensor) else i.contiguous()
+                    for i in args
+                ),
+                **{
+                    k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+                    for k, v in kwargs.items()
+                },
+            )
+
+        return wrapper
+
+    return decorator
+
+
+def checkpoint(fn):
+    def wrapper(*args, **kwargs):
+        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
+
+    return wrapper
+
+
+@lru_cache(maxsize=None)
+def check_pytorch_version(version_s: str = "2.4") -> bool:
+    return version.parse(torch.__version__) >= version.parse(version_s)
+
+
+def _cpu_device_warning():
+    import warnings
+
+    warnings.warn(
+        ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
+    )
+
+
+@lru_cache(maxsize=None)
+def get_multiprocessor_count(tensor_idx: int = 0) -> int:
+    try:
+        return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
+            "multiprocessor_count"
+        ]
+    except BaseException:
+        _cpu_device_warning()
+        return -1
+
+
+@lru_cache(maxsize=None)
+def get_available_device() -> str:
+    try:
+        return triton.runtime.driver.active.get_current_target().backend
+    except BaseException:
+        _cpu_device_warning()
+        return "cpu"
+
+
+@lru_cache(maxsize=None)
+def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+    device = get_available_device()
+    if device == "cuda":
+        return "nvidia"
+    elif device == "hip":
+        return "amd"
+    elif device == "xpu":
+        return "intel"
+    else:
+        return device
+
+
+# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+# Therefore, we need to check the triton backend to determine the actual GPU vendor.
+device = get_available_device() if get_available_device() != "hip" else "cuda"
+device_torch_lib = getattr(torch, device)
+device_platform = _check_platform()
+
+is_amd = device_platform == "amd"
+is_intel = device_platform == "intel"
+is_nvidia = device_platform == "nvidia"
+is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
+is_nvidia_hopper = is_nvidia and (
+    "NVIDIA H" in torch.cuda.get_device_name(0)
+    or torch.cuda.get_device_capability()[0] >= 9
+)
+use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+# Nvidia Ampere or newer, haven't check AMD and intel yet.
+is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
+is_gather_supported = hasattr(triton.language, "gather")
+
+
+def get_all_max_shared_mem():
+    try:
+        return [
+            triton.runtime.driver.active.utils.get_device_properties(i)[
+                "max_shared_mem"
+            ]
+            for i in range(device_torch_lib.device_count())
+        ]
+    except BaseException:
+        _cpu_device_warning()
+        return [-1]
+
+
+class Backend(Enum):
+    ADA = 101376  # RTX 4090
+    AMPERE = 166912  # A100
+    HOPPER = 232448  # H100
+    DEFAULT = 102400  # Default
+
+    @classmethod
+    def get_shared_memory(cls, arch: str) -> int:
+        try:
+            return cls[arch.upper()].value
+        except KeyError:
+            return cls.DEFAULT.value
+
+
+@lru_cache(maxsize=None)
+def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+    try:
+        device_shared_mem_list = get_all_max_shared_mem()
+        max_shared_memory = device_shared_mem_list[tensor_idx]
+        return max_shared_memory >= Backend.get_shared_memory(arch)
+    except Exception:
+        return False
+
+
+if check_pytorch_version("2.4"):
+    device = "cuda" if device == "cpu" else device
+    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
+    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
+
+    def custom_device_ctx(index: int):
+        return device_torch_lib.device(index)
+
+else:
+    assert (
+        device == "cuda"
+    ), "Only cuda device is supported for PyTorch version < 2.4.0."
+    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
+    autocast_custom_bwd = device_torch_lib.amp.custom_bwd
+
+    def custom_device_ctx(index: int):
+        return torch.cuda.device(index)
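The two decorators this new module exports, tensor_cache and input_guard, are the pieces most likely to be reused elsewhere. A minimal usage sketch follows; it assumes sglang 0.5.2 with a CUDA build of PyTorch and Triton installed, and build_mask / row_sums are illustrative names, not part of the package.

import torch

from sglang.srt.layers.attention.fla.utils import input_guard, tensor_cache


@tensor_cache
def build_mask(lens: torch.Tensor) -> torch.Tensor:
    # Recomputed only when called with tensor objects not seen in the last few calls;
    # the cache matches arguments by identity (is), not by value.
    return lens[:, None] > torch.arange(int(lens.max()), device=lens.device)[None, :]


@input_guard
def row_sums(x: torch.Tensor) -> torch.Tensor:
    # input_guard hands the function a contiguous tensor and activates its device.
    return x.sum(dim=-1)


lens = torch.tensor([3, 5], device="cuda")
assert build_mask(lens) is build_mask(lens)  # second call is a cache hit
print(row_sums(torch.ones(2, 4, device="cuda").t()))  # non-contiguous input is fine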
sglang/srt/layers/attention/fla/wy_fast.py (new file)
@@ -0,0 +1,158 @@
+# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
+from sglang.srt.layers.attention.fla.op import safe_exp
+from sglang.srt.layers.attention.fla.utils import check_shared_mem
+
+
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+# @triton.autotune(
+#     configs=[
+#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+#         for num_warps in [2, 4, 8]
+#         for num_stages in [2, 3, 4]
+#     ],
+#     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
+# )
+@triton.jit(do_not_specialize=["T"])
+def recompute_w_u_fwd_kernel(
+    k,
+    v,
+    beta,
+    w,
+    u,
+    A,
+    g,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    Hg: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
+            chunk_indices + i_t * 2 + 1
+        ).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
+            cu_seqlens + i_n + 1
+        ).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+    p_beta = tl.make_block_ptr(
+        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
+    )
+    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
+    p_A = tl.make_block_ptr(
+        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
+    )
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    b_A = tl.load(p_A, boundary_check=(0, 1))
+    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(
+            v + (bos * H + i_h) * V,
+            (T, V),
+            (H * V, 1),
+            (i_t * BT, i_v * BV),
+            (BT, BV),
+            (1, 0),
+        )
+        p_u = tl.make_block_ptr(
+            u + (bos * H + i_h) * V,
+            (T, V),
+            (H * V, 1),
+            (i_t * BT, i_v * BV),
+            (BT, BV),
+            (1, 0),
+        )
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_k in range(tl.cdiv(K, BK)):
+        p_k = tl.make_block_ptr(
+            k + (bos * Hg + i_h // (H // Hg)) * K,
+            (T, K),
+            (Hg * K, 1),
+            (i_t * BT, i_k * BK),
+            (BT, BK),
+            (1, 0),
+        )
+        p_w = tl.make_block_ptr(
+            w + (bos * H + i_h) * K,
+            (T, K),
+            (H * K, 1),
+            (i_t * BT, i_k * BK),
+            (BT, BK),
+            (1, 0),
+        )
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
+        b_w = tl.dot(b_A, b_kb)
+        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


+def recompute_w_u_fwd(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    g_cumsum: torch.Tensor,
+    A: torch.Tensor,
+    cu_seqlens: Optional[torch.LongTensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    B, T, Hg, K, V = *k.shape, v.shape[-1]
+    H = v.shape[-2]
+    BT = A.shape[-1]
+
+    chunk_indices = (
+        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    )
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+    BK = 64
+    BV = 64
+    u = torch.empty_like(v)
+    w = k.new_empty(B, T, H, K)
+    recompute_w_u_fwd_kernel[(NT, B * H)](
+        k=k,
+        v=v,
+        beta=beta,
+        w=w,
+        u=u,
+        A=A,
+        g=g_cumsum,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        Hg=Hg,
+        K=K,
+        V=V,
+        BT=BT,
+        BK=BK,
+        BV=BV,
+        num_warps=4,
+        num_stages=3,
+    )
+    return w, u


+fwd_recompute_w_u = recompute_w_u_fwd
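A shape sketch for the new recompute_w_u_fwd helper, with layouts inferred from the kernel's block pointers: k is (B, T, Hg, K), v is (B, T, H, V), beta and the cumulative gate are (B, T, H), and A is (B, T, H, BT) where BT is the chunk size. The sizes below are illustrative only, and the snippet assumes a CUDA GPU with Triton; it is a sketch, not an sglang-endorsed invocation.

import torch

from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd

B, T, H, Hg, K, V, BT = 1, 128, 4, 4, 64, 128, 64
k = torch.randn(B, T, Hg, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)
g_cumsum = torch.zeros(B, T, H, device="cuda", dtype=torch.float32)  # exp(0) == 1, i.e. no gating
A = torch.randn(B, T, H, BT, device="cuda", dtype=torch.bfloat16)

w, u = recompute_w_u_fwd(k, v, beta, g_cumsum, A, cu_seqlens=None)
print(w.shape, u.shape)  # (B, T, H, K) and (B, T, H, V)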
sglang/srt/layers/attention/flashinfer_backend.py
@@ -501,8 +501,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-
-
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )
         else:
             causal = True
@@ -580,8 +581,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 sm_scale=layer.scaling,
                 logits_soft_cap=layer.logit_cap,
-
-
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
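The added comment is the whole rationale for this change: reading a scale out of a GPU tensor would force a device-to-host copy, and CUDA graph capture forbids synchronization, so the backend now passes scales that are already plain Python floats. A hypothetical, self-contained illustration of that constraint (not sglang code; requires a CUDA GPU):

import torch

k_scale = torch.tensor(0.5, device="cuda")  # scale lives on the GPU
k_scale_float = float(k_scale)              # one-time sync, done before capture

x = torch.randn(8, device="cuda")
y = x * k_scale_float                       # warm-up outside the graph

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = x * k_scale_float                   # fine: host-side Python float
    # y = x * k_scale.item()                # would fail: .item() syncs during capture
g.replay()
print(y)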
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks
@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
             causal=False,
         )
         # ragged prefill
-
-
-
-
-
-
-
-
-
-
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )

     def forward(
         self,
@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def init_mha_chunk_metadata(
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)

     def forward_extend(
         self,
sglang/srt/layers/attention/hybrid_attn_backend.py
@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


@@ -12,19 +13,54 @@ class HybridAttnBackend(AttentionBackend):
     """Support different backends for prefill and decode."""

     def __init__(
-        self,
+        self,
+        model_runner: ModelRunner,
+        prefill_backend: AttentionBackend,
+        decode_backend: AttentionBackend,
     ):
+        self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend

-    def
-
-
+    def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
+        """
+        Select the appropriate attention backend based on the forward mode.
+
+        Args:
+            forward_mode: The current forward mode indicating the operation type
+
+        Returns:
+            The selected attention backend (prefill or decode)
+
+        Note:
+            - decode_or_idle: Always uses decode backend
+            - target_verify or draft_extend: Uses decode backend if speculative_attention_mode is "decode", otherwise prefill backend
+            - prefill: Always uses prefill backend
+        """
+        if forward_mode.is_decode_or_idle():
+            return self.decode_backend
+        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
+            return (
+                self.decode_backend
+                if self.model_runner.server_args.speculative_attention_mode == "decode"
+                else self.prefill_backend
+            )
         else:
-            self.prefill_backend
+            return self.prefill_backend
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        backend = self._select_backend(forward_batch.forward_mode)
+        backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+        if (
+            self.model_runner.server_args.speculative_algorithm is not None
+            and self.model_runner.server_args.speculative_attention_mode == "prefill"
+        ):
+            # When speculative decoding is enabled, we need to initialize the backend
+            # that will be used for target_verify.
+            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -36,7 +72,8 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        self.
+        backend = self._select_backend(forward_mode)
+        backend.init_forward_metadata_capture_cuda_graph(
             bs,
             num_tokens,
             req_pool_indices,
@@ -57,7 +94,8 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        self.
+        backend = self._select_backend(forward_mode)
+        backend.init_forward_metadata_replay_cuda_graph(
             bs,
             req_pool_indices,
             seq_lens,
@@ -95,6 +133,7 @@ class HybridAttnBackend(AttentionBackend):
         save_kv_cache: bool = True,
         **kwargs,
     ):
-
+        backend = self._select_backend(forward_batch.forward_mode)
+        return backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
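The dispatch rule added by _select_backend is small enough to restate as a standalone sketch; the enum and function below are illustrative stand-ins for sglang's ForwardMode and server_args, not its actual API.

from enum import Enum, auto


class Mode(Enum):
    DECODE = auto()
    IDLE = auto()
    PREFILL = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()


def select_backend(mode: Mode, speculative_attention_mode: str = "prefill") -> str:
    # Mirrors HybridAttnBackend._select_backend: decode/idle always use the decode
    # backend; the speculative verify/extend phases follow speculative_attention_mode;
    # everything else (regular prefill) uses the prefill backend.
    if mode in (Mode.DECODE, Mode.IDLE):
        return "decode"
    if mode in (Mode.TARGET_VERIFY, Mode.DRAFT_EXTEND):
        return "decode" if speculative_attention_mode == "decode" else "prefill"
    return "prefill"


assert select_backend(Mode.DECODE) == "decode"
assert select_backend(Mode.TARGET_VERIFY, "decode") == "decode"
assert select_backend(Mode.TARGET_VERIFY) == "prefill"
assert select_backend(Mode.PREFILL) == "prefill"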