sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the contents of publicly released package versions as published to their respective public registries; it is provided for informational purposes only.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/fla/utils.py (new file)
@@ -0,0 +1,331 @@
```python
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
# -*- coding: utf-8 -*-

import contextlib
import functools
import logging
import os
import sys
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional, Tuple

import torch
import triton
from packaging import version

logger = logging.getLogger(__name__)

COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"


@lru_cache(maxsize=1)
def check_environments():
    """
    Checks the current operating system, Triton version, and Python version,
    issuing warnings if they don't meet recommendations.
    This function's body only runs once due to lru_cache.
    """
    # Check Operating System
    if sys.platform == "win32":
        logger.warning(
            "Detected Windows operating system. Triton does not have an official Windows release, "
            "thus FLA will not be adapted for Windows, and any potential errors will not be fixed. "
            "Please consider using a Linux environment for compatibility."
        )

    triton_version = version.parse(triton.__version__)
    required_triton_version = version.parse("3.2.0")

    if triton_version < required_triton_version:
        logger.warning(
            f"Current Triton version {triton_version} is below the recommended 3.2.0 version. "
            "Errors may occur and these issues will not be fixed. "
            "Please consider upgrading Triton."
        )

    # Check Python version
    py_version = version.parse(f"{sys.version_info.major}.{sys.version_info.minor}")
    required_py_version = version.parse("3.11")

    if py_version < required_py_version:
        logger.warning(
            f"Current Python version {py_version} is below the recommended 3.11 version. "
            "It is recommended to upgrade to Python 3.11 or higher for the best experience."
        )

    return None


check_environments()


def get_abs_err(x, y):
    return (x.detach() - y.detach()).flatten().abs().max().item()


def get_err_ratio(x, y):
    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
    base = (x.detach()).flatten().square().mean().sqrt().item()
    return err / (base + 1e-8)


def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
    abs_atol = get_abs_err(ref, tri)
    msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
    logger.info(msg)
    error_rate = get_err_ratio(ref, tri)
    if abs_atol <= err_atol:
        return
    if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
        if error_rate > ratio:
            import warnings

            warnings.warn(msg)
    else:
        assert error_rate < ratio, msg


SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))


def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.
    This decorator will store the output of the decorated function for the most recent set of input tensors.
    The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.
    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with single-entry caching.
    """

    cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []
    cache_size = 4

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries, cache_size
        for i, entry in enumerate(cache_entries):
            last_args, last_kwargs, last_result = entry
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                ):
                    cache_entries = (
                        cache_entries[:i]
                        + cache_entries[i + 1 :]
                        + [(args, kwargs, last_result)]
                    )
                    return last_result

        result = fn(*args, **kwargs)

        if len(cache_entries) >= cache_size:
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper


def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (
            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
        )
        contiguous_kwargs = {
            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
            for k, v in kwargs.items()
        }

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = custom_device_ctx(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


contiguous = input_guard


def require_version(version, hint):
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            from transformers.utils.versions import require_version

            require_version(version, hint)
            return fn(
                ctx,
                *(
                    i if not isinstance(i, torch.Tensor) else i.contiguous()
                    for i in args
                ),
                **{
                    k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
                    for k, v in kwargs.items()
                },
            )

        return wrapper

    return decorator


def checkpoint(fn):
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)

    return wrapper


@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = "2.4") -> bool:
    return version.parse(torch.__version__) >= version.parse(version_s)


def _cpu_device_warning():
    import warnings

    warnings.warn(
        ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
    )


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    try:
        return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
            "multiprocessor_count"
        ]
    except BaseException:
        _cpu_device_warning()
        return -1


@lru_cache(maxsize=None)
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return "cpu"


@lru_cache(maxsize=None)
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
    device = get_available_device()
    if device == "cuda":
        return "nvidia"
    elif device == "hip":
        return "amd"
    elif device == "xpu":
        return "intel"
    else:
        return device


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = device_platform == "amd"
is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
is_nvidia_hopper = is_nvidia and (
    "NVIDIA H" in torch.cuda.get_device_name(0)
    or torch.cuda.get_device_capability()[0] >= 9
)
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"

# Nvidia Ampere or newer, haven't check AMD and intel yet.
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
is_gather_supported = hasattr(triton.language, "gather")


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)[
                "max_shared_mem"
            ]
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False


if check_pytorch_version("2.4"):
    device = "cuda" if device == "cpu" else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        return device_torch_lib.device(index)

else:
    assert (
        device == "cuda"
    ), "Only cuda device is supported for PyTorch version < 2.4.0."
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)
```
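The `tensor_cache` and `input_guard` decorators above are general-purpose and are the pieces most likely to be reused outside the FLA kernels. A minimal usage sketch, not part of the diff, assuming an sglang 0.5.2 install with Triton and a CUDA device (the module imports `triton` at import time); `pad_last_dim` is a hypothetical helper used only for illustration:

```python
import torch

from sglang.srt.layers.attention.fla.utils import input_guard, tensor_cache


@tensor_cache  # reuses the cached result while the same tensor objects are passed in
@input_guard   # makes tensor arguments contiguous and enters their device context
def pad_last_dim(x: torch.Tensor, multiple: int = 64) -> torch.Tensor:
    # Hypothetical helper: right-pad the last dimension to a multiple of `multiple`.
    pad = (-x.shape[-1]) % multiple
    return torch.nn.functional.pad(x, (0, pad))


x = torch.randn(2, 100, device="cuda")
a = pad_last_dim(x)  # computed once
b = pad_last_dim(x)  # same tensor object passed again: served from the 4-entry cache
assert a is b and a.shape[-1] == 128
```

Note that `tensor_cache` keys on object identity (`is`), not on tensor values, so it only pays off when the exact same tensor objects are passed repeatedly, e.g. metadata tensors reused across kernel calls within one forward pass.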
sglang/srt/layers/attention/fla/wy_fast.py (new file)
@@ -0,0 +1,158 @@
```python
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.op import safe_exp
from sglang.srt.layers.attention.fla.utils import check_shared_mem


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = 64
    BV = 64
    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    recompute_w_u_fwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        num_warps=4,
        num_stages=3,
    )
    return w, u


fwd_recompute_w_u = recompute_w_u_fwd
```
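For orientation, a shape-only sketch of calling the host-side entry point `recompute_w_u_fwd` above; this is not part of the diff. The sizes, dtypes, and random inputs are illustrative assumptions, a CUDA device with Triton is required, and `A` stands in for the per-chunk `BT x BT` blocks that the surrounding gated delta rule code would normally compute:

```python
import torch

from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd

B, T, H, Hg, K, V, BT = 1, 256, 8, 8, 64, 128, 64  # illustrative sizes only
dev, dt = "cuda", torch.bfloat16

k = torch.randn(B, T, Hg, K, device=dev, dtype=dt)   # keys (Hg key heads)
v = torch.randn(B, T, H, V, device=dev, dtype=dt)    # values (H value heads)
beta = torch.rand(B, T, H, device=dev, dtype=dt)     # per-token beta gate
g_cumsum = torch.zeros(B, T, H, device=dev, dtype=torch.float32)  # cumulative log-gates
A = torch.randn(B, T, H, BT, device=dev, dtype=dt)   # per-chunk (BT x BT) blocks

w, u = recompute_w_u_fwd(k, v, beta, g_cumsum, A, cu_seqlens=None)
print(w.shape, u.shape)  # torch.Size([1, 256, 8, 64]) torch.Size([1, 256, 8, 128])
```

When `cu_seqlens` is provided (variable-length sequences packed into a single batch), the wrapper instead derives per-sequence chunk offsets via `prepare_chunk_indices` and the kernel takes its `IS_VARLEN` branch.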