sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,331 @@
+ # Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
+ # -*- coding: utf-8 -*-
+
+ import contextlib
+ import functools
+ import logging
+ import os
+ import sys
+ from enum import Enum
+ from functools import lru_cache
+ from typing import Any, Callable, Dict, Literal, Optional, Tuple
+
+ import torch
+ import triton
+ from packaging import version
+
+ logger = logging.getLogger(__name__)
+
+ COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
+ FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
+
+
+ @lru_cache(maxsize=1)
+ def check_environments():
+     """
+     Checks the current operating system, Triton version, and Python version,
+     issuing warnings if they don't meet recommendations.
+     This function's body only runs once due to lru_cache.
+     """
+     # Check Operating System
+     if sys.platform == "win32":
+         logger.warning(
+             "Detected Windows operating system. Triton does not have an official Windows release, "
+             "thus FLA will not be adapted for Windows, and any potential errors will not be fixed. "
+             "Please consider using a Linux environment for compatibility."
+         )
+
+     triton_version = version.parse(triton.__version__)
+     required_triton_version = version.parse("3.2.0")
+
+     if triton_version < required_triton_version:
+         logger.warning(
+             f"Current Triton version {triton_version} is below the recommended 3.2.0 version. "
+             "Errors may occur and these issues will not be fixed. "
+             "Please consider upgrading Triton."
+         )
+
+     # Check Python version
+     py_version = version.parse(f"{sys.version_info.major}.{sys.version_info.minor}")
+     required_py_version = version.parse("3.11")
+
+     if py_version < required_py_version:
+         logger.warning(
+             f"Current Python version {py_version} is below the recommended 3.11 version. "
+             "It is recommended to upgrade to Python 3.11 or higher for the best experience."
+         )
+
+     return None
+
+
+ check_environments()
+
+
+ def get_abs_err(x, y):
+     return (x.detach() - y.detach()).flatten().abs().max().item()
+
+
+ def get_err_ratio(x, y):
+     err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
+     base = (x.detach()).flatten().square().mean().sqrt().item()
+     return err / (base + 1e-8)
+
+
+ def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
+     abs_atol = get_abs_err(ref, tri)
+     msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
+     logger.info(msg)
+     error_rate = get_err_ratio(ref, tri)
+     if abs_atol <= err_atol:
+         return
+     if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
+         if error_rate > ratio:
+             import warnings
+
+             warnings.warn(msg)
+     else:
+         assert error_rate < ratio, msg
+
+
+ SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
+
+
+ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+     """
+     A decorator that caches the most recent results of a function with tensor inputs.
+     This decorator will store the output of the decorated function for the most recent set of input tensors.
+     The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
+     Args:
+         fn (Callable[..., torch.Tensor]):
+             The function to be decorated. It should take tensor inputs and return tensor outputs.
+     Returns:
+         Callable[..., torch.Tensor]:
+             A wrapped version of the input function with single-entry caching.
+     """
+
+     cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []
+     cache_size = 4
+
+     @functools.wraps(fn)
+     def wrapper(*args: Any, **kwargs: Any) -> Any:
+         nonlocal cache_entries, cache_size
+         for i, entry in enumerate(cache_entries):
+             last_args, last_kwargs, last_result = entry
+             if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
+                 if all(a is b for a, b in zip(args, last_args)) and all(
+                     k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
+                 ):
+                     cache_entries = (
+                         cache_entries[:i]
+                         + cache_entries[i + 1 :]
+                         + [(args, kwargs, last_result)]
+                     )
+                     return last_result
+
+         result = fn(*args, **kwargs)
+
+         if len(cache_entries) >= cache_size:
+             cache_entries = cache_entries[1:]
+         cache_entries.append((args, kwargs, result))
+         return result
+
+     return wrapper
+
+
+ def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+     """
+     A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+     """
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         contiguous_args = (
+             i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
+         )
+         contiguous_kwargs = {
+             k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+             for k, v in kwargs.items()
+         }
+
+         tensor = None
+         for arg in args:
+             if isinstance(arg, torch.Tensor):
+                 tensor = arg
+                 break
+         if tensor is None:
+             for value in kwargs.values():
+                 if isinstance(value, torch.Tensor):
+                     tensor = value
+                     break
+
+         if tensor is not None:
+             ctx = custom_device_ctx(tensor.device.index)
+         else:
+             ctx = contextlib.nullcontext()
+
+         with ctx:
+             return fn(*contiguous_args, **contiguous_kwargs)
+
+     return wrapper
+
+
+ contiguous = input_guard
+
+
+ def require_version(version, hint):
+     """
+     Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
+     """
+
+     def decorator(fn):
+         @functools.wraps(fn)
+         def wrapper(ctx, *args, **kwargs):
+             from transformers.utils.versions import require_version
+
+             require_version(version, hint)
+             return fn(
+                 ctx,
+                 *(
+                     i if not isinstance(i, torch.Tensor) else i.contiguous()
+                     for i in args
+                 ),
+                 **{
+                     k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+                     for k, v in kwargs.items()
+                 },
+             )
+
+         return wrapper
+
+     return decorator
+
+
+ def checkpoint(fn):
+     def wrapper(*args, **kwargs):
+         return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
+
+     return wrapper
+
+
+ @lru_cache(maxsize=None)
+ def check_pytorch_version(version_s: str = "2.4") -> bool:
+     return version.parse(torch.__version__) >= version.parse(version_s)
+
+
+ def _cpu_device_warning():
+     import warnings
+
+     warnings.warn(
+         ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
+     )
+
+
+ @lru_cache(maxsize=None)
+ def get_multiprocessor_count(tensor_idx: int = 0) -> int:
+     try:
+         return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
+             "multiprocessor_count"
+         ]
+     except BaseException:
+         _cpu_device_warning()
+         return -1
+
+
+ @lru_cache(maxsize=None)
+ def get_available_device() -> str:
+     try:
+         return triton.runtime.driver.active.get_current_target().backend
+     except BaseException:
+         _cpu_device_warning()
+         return "cpu"
+
+
+ @lru_cache(maxsize=None)
+ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+     device = get_available_device()
+     if device == "cuda":
+         return "nvidia"
+     elif device == "hip":
+         return "amd"
+     elif device == "xpu":
+         return "intel"
+     else:
+         return device
+
+
+ # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+ # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+ # Therefore, we need to check the triton backend to determine the actual GPU vendor.
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+ device_torch_lib = getattr(torch, device)
+ device_platform = _check_platform()
+
+ is_amd = device_platform == "amd"
+ is_intel = device_platform == "intel"
+ is_nvidia = device_platform == "nvidia"
+ is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
+ is_nvidia_hopper = is_nvidia and (
+     "NVIDIA H" in torch.cuda.get_device_name(0)
+     or torch.cuda.get_device_capability()[0] >= 9
+ )
+ use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+ # Nvidia Ampere or newer, haven't check AMD and intel yet.
+ is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
+ is_gather_supported = hasattr(triton.language, "gather")
+
+
+ def get_all_max_shared_mem():
+     try:
+         return [
+             triton.runtime.driver.active.utils.get_device_properties(i)[
+                 "max_shared_mem"
+             ]
+             for i in range(device_torch_lib.device_count())
+         ]
+     except BaseException:
+         _cpu_device_warning()
+         return [-1]
+
+
+ class Backend(Enum):
+     ADA = 101376  # RTX 4090
+     AMPERE = 166912  # A100
+     HOPPER = 232448  # H100
+     DEFAULT = 102400  # Default
+
+     @classmethod
+     def get_shared_memory(cls, arch: str) -> int:
+         try:
+             return cls[arch.upper()].value
+         except KeyError:
+             return cls.DEFAULT.value
+
+
+ @lru_cache(maxsize=None)
+ def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+     try:
+         device_shared_mem_list = get_all_max_shared_mem()
+         max_shared_memory = device_shared_mem_list[tensor_idx]
+         return max_shared_memory >= Backend.get_shared_memory(arch)
+     except Exception:
+         return False
+
+
+ if check_pytorch_version("2.4"):
+     device = "cuda" if device == "cpu" else device
+     autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
+     autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
+
+     def custom_device_ctx(index: int):
+         return device_torch_lib.device(index)
+
+ else:
+     assert (
+         device == "cuda"
+     ), "Only cuda device is supported for PyTorch version < 2.4.0."
+     autocast_custom_fwd = device_torch_lib.amp.custom_fwd
+     autocast_custom_bwd = device_torch_lib.amp.custom_bwd
+
+     def custom_device_ctx(index: int):
+         return torch.cuda.device(index)
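
A minimal usage sketch for the `tensor_cache` and `input_guard` decorators defined above (editorial illustration, not part of the diff; it assumes a CUDA device with Triton installed, and `seqlens_from_cu` / `scale` are hypothetical helpers):

import torch

from sglang.srt.layers.attention.fla.utils import input_guard, tensor_cache


@tensor_cache
def seqlens_from_cu(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Recomputed only when called with a tensor object that is not already
    # in the decorator's 4-entry identity cache.
    return cu_seqlens[1:] - cu_seqlens[:-1]


@input_guard
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # x arrives contiguous here, and the body runs under x's device context.
    return x * factor


cu = torch.tensor([0, 3, 7], device="cuda")
a = seqlens_from_cu(cu)  # computed
b = seqlens_from_cu(cu)  # served from the cache (same tensor object passed again)
y = scale(torch.randn(4, 4, device="cuda").t(), 2.0)  # .t() yields a non-contiguous view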
@@ -0,0 +1,158 @@
+ # Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
+ from sglang.srt.layers.attention.fla.op import safe_exp
+ from sglang.srt.layers.attention.fla.utils import check_shared_mem
+
+
+ @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+ # @triton.autotune(
+ #     configs=[
+ #         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+ #         for num_warps in [2, 4, 8]
+ #         for num_stages in [2, 3, 4]
+ #     ],
+ #     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
+ # )
+ @triton.jit(do_not_specialize=["T"])
+ def recompute_w_u_fwd_kernel(
+     k,
+     v,
+     beta,
+     w,
+     u,
+     A,
+     g,
+     cu_seqlens,
+     chunk_indices,
+     T,
+     H: tl.constexpr,
+     Hg: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if IS_VARLEN:
+         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
+             chunk_indices + i_t * 2 + 1
+         ).to(tl.int32)
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
+             cu_seqlens + i_n + 1
+         ).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+     p_beta = tl.make_block_ptr(
+         beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
+     )
+     p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
+     p_A = tl.make_block_ptr(
+         A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
+     )
+     b_beta = tl.load(p_beta, boundary_check=(0,))
+     b_A = tl.load(p_A, boundary_check=(0, 1))
+     b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
+
+     for i_v in range(tl.cdiv(V, BV)):
+         p_v = tl.make_block_ptr(
+             v + (bos * H + i_h) * V,
+             (T, V),
+             (H * V, 1),
+             (i_t * BT, i_v * BV),
+             (BT, BV),
+             (1, 0),
+         )
+         p_u = tl.make_block_ptr(
+             u + (bos * H + i_h) * V,
+             (T, V),
+             (H * V, 1),
+             (i_t * BT, i_v * BV),
+             (BT, BV),
+             (1, 0),
+         )
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+         b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
+         b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+         tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+     for i_k in range(tl.cdiv(K, BK)):
+         p_k = tl.make_block_ptr(
+             k + (bos * Hg + i_h // (H // Hg)) * K,
+             (T, K),
+             (Hg * K, 1),
+             (i_t * BT, i_k * BK),
+             (BT, BK),
+             (1, 0),
+         )
+         p_w = tl.make_block_ptr(
+             w + (bos * H + i_h) * K,
+             (T, K),
+             (H * K, 1),
+             (i_t * BT, i_k * BK),
+             (BT, BK),
+             (1, 0),
+         )
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
+         b_w = tl.dot(b_A, b_kb)
+         tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def recompute_w_u_fwd(
+     k: torch.Tensor,
+     v: torch.Tensor,
+     beta: torch.Tensor,
+     g_cumsum: torch.Tensor,
+     A: torch.Tensor,
+     cu_seqlens: Optional[torch.LongTensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     B, T, Hg, K, V = *k.shape, v.shape[-1]
+     H = v.shape[-2]
+     BT = A.shape[-1]
+
+     chunk_indices = (
+         prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+     )
+     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+     BK = 64
+     BV = 64
+     u = torch.empty_like(v)
+     w = k.new_empty(B, T, H, K)
+     recompute_w_u_fwd_kernel[(NT, B * H)](
+         k=k,
+         v=v,
+         beta=beta,
+         w=w,
+         u=u,
+         A=A,
+         g=g_cumsum,
+         cu_seqlens=cu_seqlens,
+         chunk_indices=chunk_indices,
+         T=T,
+         H=H,
+         Hg=Hg,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         num_warps=4,
+         num_stages=3,
+     )
+     return w, u
+
+
+ fwd_recompute_w_u = recompute_w_u_fwd
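
A shape-only sketch of calling `recompute_w_u_fwd` (editorial illustration, not part of the diff; the sizes are illustrative, a CUDA device with Triton is assumed, and the input shapes follow the block pointers in the kernel above):

import torch

from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd

B, T, Hg, H, K, V, BT = 2, 128, 8, 8, 64, 128, 64
k = torch.randn(B, T, Hg, K, device="cuda", dtype=torch.bfloat16)    # (B, T, Hg, K)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)     # (B, T, H, V)
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)      # (B, T, H)
g_cumsum = torch.zeros(B, T, H, device="cuda", dtype=torch.float32)  # (B, T, H), exponentiated in-kernel
A = torch.randn(B, T, H, BT, device="cuda", dtype=torch.bfloat16)    # (B, T, H, BT)

# w has shape (B, T, H, K) and u has shape (B, T, H, V);
# cu_seqlens=None selects the fixed-length (non-varlen) path.
w, u = recompute_w_u_fwd(k, v, beta, g_cumsum, A, cu_seqlens=None)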