sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py
@@ -0,0 +1,232 @@
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.layers.attention.fla.utils import input_guard
+
+
+ @triton.heuristics(
+     {
+         "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
+         "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+     }
+ )
+ @triton.jit(do_not_specialize=["T"])
+ def fused_sigmoid_gating_delta_rule_update_kernel(
+     A_log,
+     a,
+     dt_bias,
+     softplus_beta,
+     softplus_threshold,
+     q,
+     k,
+     v,
+     b,
+     o,
+     h0_source,
+     h0_indices,
+     cu_seqlens,
+     scale,
+     T,
+     B: tl.constexpr,
+     H: tl.constexpr,
+     HV: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     USE_INITIAL_STATE: tl.constexpr,
+     USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+ ):
+     """
+     Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
+     """
+     i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_n, i_hv = i_nh // HV, i_nh % HV
+     i_h = i_hv // (HV // H)
+
+     if IS_VARLEN:
+         bos, eos = (
+             tl.load(cu_seqlens + i_n).to(tl.int64),
+             tl.load(cu_seqlens + i_n + 1).to(tl.int64),
+         )
+         all = T
+         T = eos - bos
+     else:
+         bos, eos = i_n * T, i_n * T + T
+         all = B * T
+
+     o_k = i_k * BK + tl.arange(0, BK)
+     o_v = i_v * BV + tl.arange(0, BV)
+
+     p_q = q + (bos * H + i_h) * K + o_k
+     p_k = k + (bos * H + i_h) * K + o_k
+     p_v = v + (bos * HV + i_hv) * V + o_v
+     p_b = b + bos * HV + i_hv
+     p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+
+     # Gating computation pointers
+     p_A_log = A_log + i_hv
+     p_a = a + bos * HV + i_hv
+     p_dt_bias = dt_bias + i_hv
+
+     mask_k = o_k < K
+     mask_v = o_v < V
+     mask_h = mask_k[:, None] & mask_v[None, :]
+
+     b_h = tl.zeros([BK, BV], dtype=tl.float32)
+     if USE_INITIAL_STATE:
+         idx = tl.load(h0_indices + i_n)
+         if idx >= 0:
+             p_h0 = (
+                 h0_source
+                 + idx * HV * K * V
+                 + i_hv * K * V
+                 + o_k[:, None] * V
+                 + o_v[None, :]
+             )
+             b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+     for _ in range(0, T):
+         # Load inputs
+         b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
+         b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+         b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+         b_b = tl.load(p_b).to(tl.float32)
+
+         # Compute sigmoid gating
+         # Load gating parameters
+         b_A_log = tl.load(p_A_log).to(tl.float32)
+         b_a = tl.load(p_a).to(tl.float32)
+         b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
+
+         # Compute g = -exp(A_log) * softplus(a + dt_bias)
+         x = b_a + b_dt_bias
+         beta_x = softplus_beta * x
+         # Apply softplus with numerical stability
+         softplus_x = tl.where(
+             beta_x <= softplus_threshold,
+             (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
+             x,
+         )
+         b_g = -tl.exp(b_A_log) * softplus_x
+
+         # Compute beta = sigmoid(b)
+         b_beta = 1.0 / (1.0 + tl.exp(-b_b))
+
+         # Apply L2 normalization if enabled
+         if USE_QK_L2NORM_IN_KERNEL:
+             b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
+             b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+
+         b_q = b_q * scale
+
+         # Apply gating to hidden state: h *= exp(g)
+         b_h *= tl.exp(b_g)
+
+         # Delta rule: v -= sum(h * k, dim=0)
+         b_v -= tl.sum(b_h * b_k[:, None], 0)
+
+         # Apply beta gating: v *= beta
+         b_v *= b_beta
+
+         # Update hidden state: h += k[:, None] * v[None, :]
+         b_h += b_k[:, None] * b_v[None, :]
+
+         # Compute output: o = sum(h * q, dim=0)
+         b_o = tl.sum(b_h * b_q[:, None], 0)
+         tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+
+         # Update pointers for next timestep
+         p_q += H * K
+         p_k += H * K
+         p_o += HV * V
+         p_v += HV * V
+         p_b += HV
+         p_a += HV
+
+     # Store final state back to h0_source with bounds checking
+     if USE_INITIAL_STATE:
+         idx = tl.load(h0_indices + i_n)
+         if idx >= 0:
+             p_h0 = (
+                 h0_source
+                 + idx * HV * K * V
+                 + i_hv * K * V
+                 + o_k[:, None] * V
+                 + o_v[None, :]
+             )
+             tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
+
+
+ @input_guard
+ def fused_sigmoid_gating_delta_rule_update(
+     A_log: torch.Tensor,
+     a: torch.Tensor,
+     dt_bias: torch.Tensor,
+     softplus_beta: float,
+     softplus_threshold: float,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     b: torch.Tensor,
+     initial_state_source: torch.Tensor,
+     initial_state_indices: torch.Tensor,
+     scale: Optional[float] = None,
+     use_qk_l2norm_in_kernel: bool = False,
+     cu_seqlens: Optional[torch.Tensor] = None,
+ ):
+     """
+     Fused triton implementation of sigmoid gating delta rule update.
+     This function uses a single fused kernel that combines both sigmoid gating computation
+     and the recurrent delta rule update for better performance.
+     """
+     B, T, H, K, V = *k.shape, v.shape[-1]
+     HV = v.shape[2]
+     N = B if cu_seqlens is None else len(cu_seqlens) - 1
+     BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
+     NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+     assert NK == 1, "NK > 1 is not supported yet"
+     num_stages = 3
+     num_warps = 1
+
+     if scale is None:
+         scale = k.shape[-1] ** -0.5
+     else:
+         assert scale > 0, "scale must be positive"
+
+     o = q.new_empty(NK, *v.shape)
+     grid = (NK, NV, N * HV)
+
+     fused_sigmoid_gating_delta_rule_update_kernel[grid](
+         A_log=A_log,
+         a=a,
+         dt_bias=dt_bias,
+         softplus_beta=softplus_beta,
+         softplus_threshold=softplus_threshold,
+         q=q,
+         k=k,
+         v=v,
+         b=b,
+         o=o,
+         h0_source=initial_state_source,
+         h0_indices=initial_state_indices,
+         cu_seqlens=cu_seqlens,
+         scale=scale,
+         T=T,
+         B=B,
+         H=H,
+         HV=HV,
+         K=K,
+         V=V,
+         BK=BK,
+         BV=BV,
+         USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
+     o = o.squeeze(0)
+     return o
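For orientation, here is a minimal smoke-test sketch of the new entry point. The shapes and values are hypothetical, inferred from the kernel's pointer arithmetic rather than documented anywhere in the diff, and a CUDA device with Triton is assumed; the authoritative call sites are presumably the Qwen3-Next / hybrid linear-attention files added in this release. Reading the pointers above: `q`/`k` are `(B, T, H, K)`, `v` is `(B, T, HV, V)`, the gating inputs `a`/`b` are `(B, T, HV)`, `A_log`/`dt_bias` are `(HV,)`, and the recurrent-state pool is `(num_slots, HV, K, V)` addressed through `initial_state_indices`:

import torch
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
    fused_sigmoid_gating_delta_rule_update,
)

B, T, H, HV, K, V = 2, 16, 4, 4, 64, 128  # hypothetical sizes; H == HV here
dev = "cuda"
q = torch.randn(B, T, H, K, device=dev, dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device=dev, dtype=torch.bfloat16)
v = torch.randn(B, T, HV, V, device=dev, dtype=torch.bfloat16)
a = torch.randn(B, T, HV, device=dev)         # gating pre-activation
b = torch.randn(B, T, HV, device=dev)         # beta logits, passed through sigmoid in-kernel
A_log = torch.zeros(HV, device=dev)           # decay: g = -exp(A_log) * softplus(a + dt_bias)
dt_bias = torch.zeros(HV, device=dev)
state = torch.zeros(B, HV, K, V, device=dev)  # one recurrent-state slot per sequence
idx = torch.arange(B, device=dev, dtype=torch.int32)

o = fused_sigmoid_gating_delta_rule_update(
    A_log, a, dt_bias,
    softplus_beta=1.0, softplus_threshold=20.0,
    q=q, k=k, v=v, b=b,
    initial_state_source=state, initial_state_indices=idx,
    use_qk_l2norm_in_kernel=True,
)
assert o.shape == v.shape  # output follows the value layout

Note that the final `tl.store` writes the updated hidden state back into the slot picked by `h0_indices`, so `state` is mutated in place and callers can carry per-request recurrent state across decode steps.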
sglang/srt/layers/attention/fla/index.py
@@ -0,0 +1,37 @@
+ # Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import torch
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.layers.attention.fla.utils import tensor_cache
+
+
+ @tensor_cache
+ def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+     return cu_seqlens[1:] - cu_seqlens[:-1]
+
+
+ @tensor_cache
+ def prepare_chunk_indices(
+     cu_seqlens: torch.LongTensor, chunk_size: int
+ ) -> torch.LongTensor:
+     indices = torch.cat(
+         [
+             torch.arange(n)
+             for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
+         ]
+     )
+     return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
+
+
+ @tensor_cache
+ def prepare_chunk_offsets(
+     cu_seqlens: torch.LongTensor, chunk_size: int
+ ) -> torch.LongTensor:
+     return torch.cat(
+         [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
+     ).cumsum(-1)
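These helpers map a varlen `cu_seqlens` layout to per-chunk program indices for the chunked FLA kernels above (`tensor_cache`, from the sibling utils module, memoizes repeated calls with the same arguments). A worked example with hypothetical values makes the output layout concrete: two sequences of lengths 5 and 8, chunked with `chunk_size=4`, yield two chunks each:

import torch
from sglang.srt.layers.attention.fla.index import (
    prepare_chunk_indices,
    prepare_chunk_offsets,
    prepare_lens,
)

cu_seqlens = torch.tensor([0, 5, 13])   # two sequences: lengths 5 and 8
prepare_lens(cu_seqlens)                # tensor([5, 8])
prepare_chunk_indices(cu_seqlens, 4)    # tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
                                        # each row is (sequence id, chunk id within it)
prepare_chunk_offsets(cu_seqlens, 4)    # tensor([0, 2, 4]): cumulative chunk counts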
sglang/srt/layers/attention/fla/l2norm.py
@@ -0,0 +1,150 @@
+ # Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/l2norm.py
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.layers.attention.fla.utils import input_guard
+
+ BT_LIST = [8, 16, 32, 64, 128]
+
+
+ # @triton.autotune(
+ #     configs=[
+ #         triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
+ #     ],
+ #     key=["D"],
+ # )
+ @triton.jit
+ def l2norm_fwd_kernel1(
+     x,
+     y,
+     D,
+     BD: tl.constexpr,
+     eps,
+ ):
+     i_t = tl.program_id(0)
+     x += i_t * D
+     y += i_t * D
+     # Compute the squared L2 norm of the row
+     cols = tl.arange(0, BD)
+     mask = cols < D
+     b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
+     b_var = tl.sum(b_x * b_x, axis=0)
+     b_rstd = 1 / tl.sqrt(b_var + eps)
+     # tl.store(Rstd + i_t, rstd)
+     # Normalize
+     b_y = b_x * b_rstd
+     tl.store(y + cols, b_y, mask=mask)
+
+
+ # @triton.autotune(
+ #     configs=[
+ #         triton.Config({"BT": BT}, num_warps=num_warps)
+ #         for num_warps in [1, 2, 4, 8, 16]
+ #         for BT in BT_LIST
+ #     ],
+ #     key=["D", "NB"],
+ # )
+ @triton.jit
+ def l2norm_fwd_kernel(
+     x,
+     y,
+     eps,
+     NB: tl.constexpr,
+     T: tl.constexpr,
+     D: tl.constexpr,
+     BT: tl.constexpr,
+     BD: tl.constexpr,
+ ):
+     i_t = tl.program_id(0)
+     p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+     b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
+     b_var = tl.sum(b_x * b_x, axis=1)
+     b_y = b_x / tl.sqrt(b_var + eps)[:, None]
+     p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+     tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def l2norm_fwd(
+     x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
+ ):
+     x_shape_og = x.shape
+     x = x.view(-1, x.shape[-1])
+     # allocate output
+     if output_dtype is None:
+         y = torch.empty_like(x)
+     else:
+         y = torch.empty_like(x, dtype=output_dtype)
+     assert y.stride(-1) == 1
+     T, D = x.shape[0], x.shape[-1]
+     # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
+     if D > BD:
+         raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
+
+     if D <= 512:
+         NB = triton.cdiv(T, 2048)
+
+         def grid(meta):
+             return (triton.cdiv(T, meta["BT"]),)
+
+         l2norm_fwd_kernel[grid](
+             x,
+             y,
+             eps,
+             NB=NB,
+             T=T,
+             D=D,
+             BD=BD,
+             BT=16,
+             num_warps=8,
+             num_stages=3,
+         )
+     else:
+         l2norm_fwd_kernel1[(T,)](
+             x,
+             y,
+             eps=eps,
+             D=D,
+             BD=BD,
+             num_warps=8,
+             num_stages=3,
+         )
+
+     return y.view(x_shape_og)
+
+
+ class L2NormFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     def forward(ctx, x, eps=1e-6, output_dtype=None):
+         return l2norm_fwd(x, eps, output_dtype)
+
+
+ def l2norm(
+     x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
+ ) -> torch.Tensor:
+     return L2NormFunction.apply(x, eps, output_dtype)
+
+
+ l2_norm = l2norm
+
+
+ class L2Norm(nn.Module):
+
+     def __init__(self, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None):
+         super().__init__()
+         self.eps = eps
+         self.output_dtype = output_dtype
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return l2norm(x, self.eps, self.output_dtype)
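`L2Norm`/`l2norm` normalize the last dimension to unit L2 norm, dispatching to the blocked kernel for `D <= 512` and the one-row-per-program kernel otherwise. A minimal usage sketch, with assumed shapes and assuming a CUDA device for Triton:

import torch
from sglang.srt.layers.attention.fla.l2norm import L2Norm, l2norm

x = torch.randn(4, 1024, 128, device="cuda", dtype=torch.bfloat16)
y = L2Norm(eps=1e-6)(x)            # equivalently: y = l2norm(x)
norms = y.float().norm(dim=-1)     # ~= 1.0 everywhere, up to eps and bf16 rounding

Since `L2NormFunction` defines no `backward`, this path is forward/inference-only; the fused recurrent kernel above instead folds the same q/k normalization in via `use_qk_l2norm_in_kernel=True`.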