sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,331 @@
+# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
+# -*- coding: utf-8 -*-
+
+import contextlib
+import functools
+import logging
+import os
+import sys
+from enum import Enum
+from functools import lru_cache
+from typing import Any, Callable, Dict, Literal, Optional, Tuple
+
+import torch
+import triton
+from packaging import version
+
+logger = logging.getLogger(__name__)
+
+COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
+FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
+
+
+@lru_cache(maxsize=1)
+def check_environments():
+    """
+    Checks the current operating system, Triton version, and Python version,
+    issuing warnings if they don't meet recommendations.
+    This function's body only runs once due to lru_cache.
+    """
+    # Check Operating System
+    if sys.platform == "win32":
+        logger.warning(
+            "Detected Windows operating system. Triton does not have an official Windows release, "
+            "thus FLA will not be adapted for Windows, and any potential errors will not be fixed. "
+            "Please consider using a Linux environment for compatibility."
+        )
+
+    triton_version = version.parse(triton.__version__)
+    required_triton_version = version.parse("3.2.0")
+
+    if triton_version < required_triton_version:
+        logger.warning(
+            f"Current Triton version {triton_version} is below the recommended 3.2.0 version. "
+            "Errors may occur and these issues will not be fixed. "
+            "Please consider upgrading Triton."
+        )
+
+    # Check Python version
+    py_version = version.parse(f"{sys.version_info.major}.{sys.version_info.minor}")
+    required_py_version = version.parse("3.11")
+
+    if py_version < required_py_version:
+        logger.warning(
+            f"Current Python version {py_version} is below the recommended 3.11 version. "
+            "It is recommended to upgrade to Python 3.11 or higher for the best experience."
+        )
+
+    return None
+
+
+check_environments()
+
+
+def get_abs_err(x, y):
+    return (x.detach() - y.detach()).flatten().abs().max().item()
+
+
+def get_err_ratio(x, y):
+    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
+    base = (x.detach()).flatten().square().mean().sqrt().item()
+    return err / (base + 1e-8)
+
+
+def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
+    abs_atol = get_abs_err(ref, tri)
+    msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
+    logger.info(msg)
+    error_rate = get_err_ratio(ref, tri)
+    if abs_atol <= err_atol:
+        return
+    if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
+        if error_rate > ratio:
+            import warnings
+
+            warnings.warn(msg)
+    else:
+        assert error_rate < ratio, msg
+
+
+SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
+
+
+def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+    """
+    A decorator that caches the most recent results of a function with tensor inputs.
+    This decorator will store the output of the decorated function for the most recent set of input tensors.
+    The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
+    Args:
+        fn (Callable[..., torch.Tensor]):
+            The function to be decorated. It should take tensor inputs and return tensor outputs.
+    Returns:
+        Callable[..., torch.Tensor]:
+            A wrapped version of the input function with single-entry caching.
+    """
+
+    cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []
+    cache_size = 4
+
+    @functools.wraps(fn)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        nonlocal cache_entries, cache_size
+        for i, entry in enumerate(cache_entries):
+            last_args, last_kwargs, last_result = entry
+            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
+                if all(a is b for a, b in zip(args, last_args)) and all(
+                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
+                ):
+                    cache_entries = (
+                        cache_entries[:i]
+                        + cache_entries[i + 1 :]
+                        + [(args, kwargs, last_result)]
+                    )
+                    return last_result
+
+        result = fn(*args, **kwargs)
+
+        if len(cache_entries) >= cache_size:
+            cache_entries = cache_entries[1:]
+        cache_entries.append((args, kwargs, result))
+        return result
+
+    return wrapper
+
+
+def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+    """
+    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        contiguous_args = (
+            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
+        )
+        contiguous_kwargs = {
+            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+            for k, v in kwargs.items()
+        }
+
+        tensor = None
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                tensor = arg
+                break
+        if tensor is None:
+            for value in kwargs.values():
+                if isinstance(value, torch.Tensor):
+                    tensor = value
+                    break
+
+        if tensor is not None:
+            ctx = custom_device_ctx(tensor.device.index)
+        else:
+            ctx = contextlib.nullcontext()
+
+        with ctx:
+            return fn(*contiguous_args, **contiguous_kwargs)
+
+    return wrapper
+
+
+contiguous = input_guard
+
+
+def require_version(version, hint):
+    """
+    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
+    """
+
+    def decorator(fn):
+        @functools.wraps(fn)
+        def wrapper(ctx, *args, **kwargs):
+            from transformers.utils.versions import require_version
+
+            require_version(version, hint)
+            return fn(
+                ctx,
+                *(
+                    i if not isinstance(i, torch.Tensor) else i.contiguous()
+                    for i in args
+                ),
+                **{
+                    k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+                    for k, v in kwargs.items()
+                },
+            )
+
+        return wrapper
+
+    return decorator
+
+
+def checkpoint(fn):
+    def wrapper(*args, **kwargs):
+        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
+
+    return wrapper
+
+
+@lru_cache(maxsize=None)
+def check_pytorch_version(version_s: str = "2.4") -> bool:
+    return version.parse(torch.__version__) >= version.parse(version_s)
+
+
+def _cpu_device_warning():
+    import warnings
+
+    warnings.warn(
+        ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
+    )
+
+
+@lru_cache(maxsize=None)
+def get_multiprocessor_count(tensor_idx: int = 0) -> int:
+    try:
+        return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
+            "multiprocessor_count"
+        ]
+    except BaseException:
+        _cpu_device_warning()
+        return -1
+
+
+@lru_cache(maxsize=None)
+def get_available_device() -> str:
+    try:
+        return triton.runtime.driver.active.get_current_target().backend
+    except BaseException:
+        _cpu_device_warning()
+        return "cpu"
+
+
+@lru_cache(maxsize=None)
+def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+    device = get_available_device()
+    if device == "cuda":
+        return "nvidia"
+    elif device == "hip":
+        return "amd"
+    elif device == "xpu":
+        return "intel"
+    else:
+        return device
+
+
+# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+# Therefore, we need to check the triton backend to determine the actual GPU vendor.
+device = get_available_device() if get_available_device() != "hip" else "cuda"
+device_torch_lib = getattr(torch, device)
+device_platform = _check_platform()
+
+is_amd = device_platform == "amd"
+is_intel = device_platform == "intel"
+is_nvidia = device_platform == "nvidia"
+is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
+is_nvidia_hopper = is_nvidia and (
+    "NVIDIA H" in torch.cuda.get_device_name(0)
+    or torch.cuda.get_device_capability()[0] >= 9
+)
+use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+# Nvidia Ampere or newer, haven't check AMD and intel yet.
+is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
+is_gather_supported = hasattr(triton.language, "gather")
+
+
+def get_all_max_shared_mem():
+    try:
+        return [
+            triton.runtime.driver.active.utils.get_device_properties(i)[
+                "max_shared_mem"
+            ]
+            for i in range(device_torch_lib.device_count())
+        ]
+    except BaseException:
+        _cpu_device_warning()
+        return [-1]
+
+
+class Backend(Enum):
+    ADA = 101376  # RTX 4090
+    AMPERE = 166912  # A100
+    HOPPER = 232448  # H100
+    DEFAULT = 102400  # Default
+
+    @classmethod
+    def get_shared_memory(cls, arch: str) -> int:
+        try:
+            return cls[arch.upper()].value
+        except KeyError:
+            return cls.DEFAULT.value
+
+
+@lru_cache(maxsize=None)
+def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+    try:
+        device_shared_mem_list = get_all_max_shared_mem()
+        max_shared_memory = device_shared_mem_list[tensor_idx]
+        return max_shared_memory >= Backend.get_shared_memory(arch)
+    except Exception:
+        return False
+
+
+if check_pytorch_version("2.4"):
+    device = "cuda" if device == "cpu" else device
+    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
+    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
+
+    def custom_device_ctx(index: int):
+        return device_torch_lib.device(index)
+
+else:
+    assert (
+        device == "cuda"
+    ), "Only cuda device is supported for PyTorch version < 2.4.0."
+    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
+    autocast_custom_bwd = device_torch_lib.amp.custom_bwd
+
+    def custom_device_ctx(index: int):
+        return torch.cuda.device(index)
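The `tensor_cache` and `input_guard` decorators added above are what the other new fla kernels lean on for metadata reuse and contiguity. A minimal usage sketch, assuming the sglang environment from the diff; `count_tokens` and `scaled` are illustrative helpers, not part of the package:

```python
import torch

from sglang.srt.layers.attention.fla.utils import input_guard, tensor_cache


@tensor_cache
def count_tokens(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Repeated calls with the *same* tensor object hit the 4-entry cache
    # instead of recomputing the per-sequence lengths.
    return cu_seqlens[1:] - cu_seqlens[:-1]


@input_guard
def scaled(x: torch.Tensor, factor: float = 2.0) -> torch.Tensor:
    # input_guard makes x contiguous and enters the device context of the
    # first tensor argument before calling the wrapped function.
    return x * factor


cu = torch.tensor([0, 3, 7, 12])
assert count_tokens(cu) is count_tokens(cu)  # second call is a cache hit
y = scaled(torch.arange(6).reshape(2, 3).t())  # non-contiguous input is fine
```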
@@ -0,0 +1,158 @@
+# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
+from sglang.srt.layers.attention.fla.op import safe_exp
+from sglang.srt.layers.attention.fla.utils import check_shared_mem
+
+
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+# @triton.autotune(
+#     configs=[
+#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+#         for num_warps in [2, 4, 8]
+#         for num_stages in [2, 3, 4]
+#     ],
+#     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
+# )
+@triton.jit(do_not_specialize=["T"])
+def recompute_w_u_fwd_kernel(
+    k,
+    v,
+    beta,
+    w,
+    u,
+    A,
+    g,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    Hg: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
+            chunk_indices + i_t * 2 + 1
+        ).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
+            cu_seqlens + i_n + 1
+        ).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+    p_beta = tl.make_block_ptr(
+        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
+    )
+    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
+    p_A = tl.make_block_ptr(
+        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
+    )
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    b_A = tl.load(p_A, boundary_check=(0, 1))
+    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(
+            v + (bos * H + i_h) * V,
+            (T, V),
+            (H * V, 1),
+            (i_t * BT, i_v * BV),
+            (BT, BV),
+            (1, 0),
+        )
+        p_u = tl.make_block_ptr(
+            u + (bos * H + i_h) * V,
+            (T, V),
+            (H * V, 1),
+            (i_t * BT, i_v * BV),
+            (BT, BV),
+            (1, 0),
+        )
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_k in range(tl.cdiv(K, BK)):
+        p_k = tl.make_block_ptr(
+            k + (bos * Hg + i_h // (H // Hg)) * K,
+            (T, K),
+            (Hg * K, 1),
+            (i_t * BT, i_k * BK),
+            (BT, BK),
+            (1, 0),
+        )
+        p_w = tl.make_block_ptr(
+            w + (bos * H + i_h) * K,
+            (T, K),
+            (H * K, 1),
+            (i_t * BT, i_k * BK),
+            (BT, BK),
+            (1, 0),
+        )
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
+        b_w = tl.dot(b_A, b_kb)
+        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+
+def recompute_w_u_fwd(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    g_cumsum: torch.Tensor,
+    A: torch.Tensor,
+    cu_seqlens: Optional[torch.LongTensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    B, T, Hg, K, V = *k.shape, v.shape[-1]
+    H = v.shape[-2]
+    BT = A.shape[-1]
+
+    chunk_indices = (
+        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    )
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+    BK = 64
+    BV = 64
+    u = torch.empty_like(v)
+    w = k.new_empty(B, T, H, K)
+    recompute_w_u_fwd_kernel[(NT, B * H)](
+        k=k,
+        v=v,
+        beta=beta,
+        w=w,
+        u=u,
+        A=A,
+        g=g_cumsum,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        Hg=Hg,
+        K=K,
+        V=V,
+        BT=BT,
+        BK=BK,
+        BV=BV,
+        num_warps=4,
+        num_stages=3,
+    )
+    return w, u
+
+
+fwd_recompute_w_u = recompute_w_u_fwd
@@ -501,8 +501,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )
         else:
             causal = True
@@ -580,8 +581,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 sm_scale=layer.scaling,
                 logits_soft_cap=layer.logit_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
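The switch above from the tensor-valued `k_scale`/`v_scale` to `k_scale_float`/`v_scale_float` exists because reading a GPU scalar back to the host synchronizes the stream, which is not allowed while a CUDA graph is being captured. A standalone illustration of that constraint (not sglang code; assumes a CUDA device and PyTorch's stock CUDA graph API):

```python
import torch

k_scale = torch.tensor(0.5, device="cuda")  # tensor-valued scale (old style)
k_scale_float = float(k_scale)              # extracted once, outside capture

g = torch.cuda.CUDAGraph()
x = torch.ones(4, device="cuda")
with torch.cuda.graph(g):
    # Calling k_scale.item() here would trigger a device-to-host copy and
    # fail the capture; the plain float participates only as a kernel argument.
    y = x * k_scale_float

g.replay()  # re-runs the captured multiply
```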
@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks
@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
                 causal=False,
             )
         # ragged prefill
-        self.ragged_wrapper.begin_forward(
-            qo_indptr=qo_indptr,
-            kv_indptr=qo_indptr,
-            num_qo_heads=self.num_local_heads,
-            num_kv_heads=self.num_local_heads,
-            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-            head_dim_vo=self.v_head_dim,
-            q_data_type=self.q_data_type,
-            causal=True,
-        )
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )

     def forward(
         self,
@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)

     def forward_extend(
         self,
@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


@@ -12,19 +13,54 @@ class HybridAttnBackend(AttentionBackend):
     """Support different backends for prefill and decode."""

     def __init__(
-        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+        self,
+        model_runner: ModelRunner,
+        prefill_backend: AttentionBackend,
+        decode_backend: AttentionBackend,
     ):
+        self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend

-    def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
-            self.decode_backend.init_forward_metadata(forward_batch)
+    def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
+        """
+        Select the appropriate attention backend based on the forward mode.
+
+        Args:
+            forward_mode: The current forward mode indicating the operation type
+
+        Returns:
+            The selected attention backend (prefill or decode)
+
+        Note:
+            - decode_or_idle: Always uses decode backend
+            - target_verify or draft_extend: Uses decode backend if speculative_attention_mode is "decode", otherwise prefill backend
+            - prefill: Always uses prefill backend
+        """
+        if forward_mode.is_decode_or_idle():
+            return self.decode_backend
+        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
+            return (
+                self.decode_backend
+                if self.model_runner.server_args.speculative_attention_mode == "decode"
+                else self.prefill_backend
+            )
         else:
-            self.prefill_backend.init_forward_metadata(forward_batch)
+            return self.prefill_backend
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        backend = self._select_backend(forward_batch.forward_mode)
+        backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+        if (
+            self.model_runner.server_args.speculative_algorithm is not None
+            and self.model_runner.server_args.speculative_attention_mode == "prefill"
+        ):
+            # When speculative decoding is enabled, we need to initialize the backend
+            # that will be used for target_verify.
+            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -36,7 +72,8 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        self.decode_backend.init_forward_metadata_capture_cuda_graph(
+        backend = self._select_backend(forward_mode)
+        backend.init_forward_metadata_capture_cuda_graph(
             bs,
             num_tokens,
             req_pool_indices,
@@ -57,7 +94,8 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        self.decode_backend.init_forward_metadata_replay_cuda_graph(
+        backend = self._select_backend(forward_mode)
+        backend.init_forward_metadata_replay_cuda_graph(
             bs,
             req_pool_indices,
             seq_lens,
@@ -95,6 +133,7 @@ class HybridAttnBackend(AttentionBackend):
         save_kv_cache: bool = True,
         **kwargs,
     ):
-        return self.prefill_backend.forward_extend(
+        backend = self._select_backend(forward_batch.forward_mode)
+        return backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
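For reference, the dispatch rule that `_select_backend` introduces above, restated as a standalone function (illustrative only; the real method consults `ForwardMode` predicates and `model_runner.server_args.speculative_attention_mode` rather than strings):

```python
def select_backend_kind(forward_mode: str, speculative_attention_mode: str = "prefill") -> str:
    # Mirrors HybridAttnBackend._select_backend: decode/idle always use the
    # decode backend; target_verify and draft_extend follow the configured
    # speculative_attention_mode; everything else uses the prefill backend.
    if forward_mode in ("decode", "idle"):
        return "decode"
    if forward_mode in ("target_verify", "draft_extend"):
        return "decode" if speculative_attention_mode == "decode" else "prefill"
    return "prefill"


assert select_backend_kind("decode") == "decode"
assert select_backend_kind("target_verify") == "prefill"
assert select_backend_kind("target_verify", "decode") == "decode"
assert select_backend_kind("extend") == "prefill"
```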