sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/fla/layernorm_gated.py
@@ -0,0 +1,326 @@
+ # Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
+ # Copyright (c) 2024, Tri Dao.
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
+ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
+
+ import math
+
+ import torch
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+ from einops import rearrange
+
+
+ def rms_norm_ref(
+     x,
+     weight,
+     bias,
+     z=None,
+     eps=1e-6,
+     group_size=None,
+     norm_before_gate=True,
+     upcast=True,
+ ):
+     dtype = x.dtype
+     N = x.shape[-1]
+     weight = weight.float()
+     bias = bias.float() if bias is not None else None
+     if upcast:
+         x = x.float()
+         z = z.float() if z is not None else z
+     if z is not None and not norm_before_gate:
+         x = x * F.silu(z)
+     if group_size is None:
+         rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
+         out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
+     else:
+         x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
+         rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
+         out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
+         if bias is not None:
+             out = out + bias
+     if z is not None and norm_before_gate:
+         out *= F.silu(z)
+     return out.to(dtype)
+
+
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
+ @triton.jit
+ def _layer_norm_fwd_1pass_kernel(
+     X,  # pointer to the input
+     Y,  # pointer to the output
+     W,  # pointer to the weights
+     B,  # pointer to the biases
+     Z,  # pointer to the other branch
+     Mean,  # pointer to the mean
+     Rstd,  # pointer to the 1/std
+     stride_x_row,  # how much to increase the pointer when moving by 1 row
+     stride_y_row,
+     stride_z_row,
+     M,  # number of rows in X
+     N,  # number of columns in X
+     eps,  # epsilon to avoid division by zero
+     BLOCK_N: tl.constexpr,
+     HAS_BIAS: tl.constexpr,
+     HAS_Z: tl.constexpr,
+     NORM_BEFORE_GATE: tl.constexpr,
+     IS_RMS_NORM: tl.constexpr,
+ ):
+     # Map the program id to the row of X and Y it should compute.
+     row = tl.program_id(0)
+     group = tl.program_id(1)
+     X += row * stride_x_row + group * N
+     Y += row * stride_y_row + group * N
+     if HAS_Z:
+         Z += row * stride_z_row + group * N
+     if not IS_RMS_NORM:
+         Mean += group * M
+     Rstd += group * M
+     W += group * N
+     if HAS_BIAS:
+         B += group * N
+     # Compute mean and variance
+     cols = tl.arange(0, BLOCK_N)
+     x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+     if HAS_Z and not NORM_BEFORE_GATE:
+         z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
+         x *= z * tl.sigmoid(z)
+     if not IS_RMS_NORM:
+         mean = tl.sum(x, axis=0) / N
+         tl.store(Mean + row, mean)
+         xbar = tl.where(cols < N, x - mean, 0.0)
+         var = tl.sum(xbar * xbar, axis=0) / N
+     else:
+         xbar = tl.where(cols < N, x, 0.0)
+         var = tl.sum(xbar * xbar, axis=0) / N
+     rstd = 1 / tl.sqrt(var + eps)
+     tl.store(Rstd + row, rstd)
+     # Normalize and apply linear transformation
+     mask = cols < N
+     w = tl.load(W + cols, mask=mask).to(tl.float32)
+     if HAS_BIAS:
+         b = tl.load(B + cols, mask=mask).to(tl.float32)
+     x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+     y = x_hat * w + b if HAS_BIAS else x_hat * w
+     if HAS_Z and NORM_BEFORE_GATE:
+         z = tl.load(Z + cols, mask=mask).to(tl.float32)
+         y *= z * tl.sigmoid(z)
+     # Write output
+     tl.store(Y + cols, y, mask=mask)
+
+
+ def _layer_norm_fwd(
+     x,
+     weight,
+     bias,
+     eps,
+     z=None,
+     out=None,
+     group_size=None,
+     norm_before_gate=True,
+     is_rms_norm=False,
+ ):
+     M, N = x.shape
+     if group_size is None:
+         group_size = N
+     assert N % group_size == 0
+     ngroups = N // group_size
+     assert x.stride(-1) == 1
+     if z is not None:
+         assert z.stride(-1) == 1
+         assert z.shape == (M, N)
+     assert weight.shape == (N,)
+     assert weight.stride(-1) == 1
+     if bias is not None:
+         assert bias.stride(-1) == 1
+         assert bias.shape == (N,)
+     # allocate output
+     if out is not None:
+         assert out.shape == x.shape
+     else:
+         out = torch.empty_like(x)
+     assert out.stride(-1) == 1
+     mean = (
+         torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
+         if not is_rms_norm
+         else None
+     )
+     rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
+     if group_size > BLOCK_N:
+         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK_N // 256, 1), 8)
+     grid = (M, ngroups)
+     with torch.cuda.device(x.device.index):
+         _layer_norm_fwd_1pass_kernel[grid](
+             x,
+             out,
+             weight,
+             bias,
+             z,
+             mean,
+             rstd,
+             x.stride(0),
+             out.stride(0),
+             z.stride(0) if z is not None else 0,
+             M,
+             group_size,
+             eps,
+             BLOCK_N=BLOCK_N,
+             NORM_BEFORE_GATE=norm_before_gate,
+             IS_RMS_NORM=is_rms_norm,
+             num_warps=num_warps,
+         )
+     return out, mean, rstd
+
+
+ class LayerNormFn(torch.autograd.Function):
+
+     @staticmethod
+     def forward(
+         ctx,
+         x,
+         weight,
+         bias,
+         z=None,
+         eps=1e-6,
+         group_size=None,
+         norm_before_gate=True,
+         is_rms_norm=False,
+     ):
+         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+
+         x_shape_og = x.shape
+         # reshape input data into 2D tensor
+         x = x.reshape(-1, x.shape[-1])
+         if x.stride(-1) != 1:
+             x = x.contiguous()
+         if z is not None:
+             assert z.shape == x_shape_og
+             z = z.reshape(-1, z.shape[-1])
+             if z.stride(-1) != 1:
+                 z = z.contiguous()
+         weight = weight.contiguous()
+         if bias is not None:
+             bias = bias.contiguous()
+         y, mean, rstd = _layer_norm_fwd(
+             x,
+             weight,
+             bias,
+             eps,
+             z=z,
+             group_size=group_size,
+             norm_before_gate=norm_before_gate,
+             is_rms_norm=is_rms_norm,
+         )
+         return y.reshape(x_shape_og)
+
+
+ def layernorm_fn(
+     x,
+     weight,
+     bias,
+     z=None,
+     eps=1e-6,
+     group_size=None,
+     norm_before_gate=True,
+     is_rms_norm=False,
+ ):
+     return LayerNormFn.apply(
+         x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
+     )
+
+
+ def rmsnorm_fn(
+     x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
+ ):
+     return LayerNormFn.apply(
+         x, weight, bias, z, eps, group_size, norm_before_gate, True
+     )
+
+
+ class LayerNorm(torch.nn.Module):
+
+     def __init__(
+         self,
+         hidden_size,
+         eps=1e-5,
+         group_size=None,
+         norm_before_gate=True,
+         device=None,
+         dtype=None,
+     ):
+         """If group_size is not None, we do GroupNorm with each group having group_size elements.
+         group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
+         """
+
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.eps = eps
+         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+         self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+         self.group_size = group_size
+         self.norm_before_gate = norm_before_gate
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.ones_(self.weight)
+         torch.nn.init.zeros_(self.bias)
+
+     def forward(self, x, z=None):
+         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+         return layernorm_fn(
+             x,
+             self.weight,
+             self.bias,
+             z=z,
+             group_size=self.group_size,
+             eps=self.eps,
+             norm_before_gate=self.norm_before_gate,
+         )
+
+
+ class RMSNorm(torch.nn.Module):
+
+     def __init__(
+         self,
+         hidden_size,
+         eps=1e-5,
+         group_size=None,
+         norm_before_gate=True,
+         device=None,
+         dtype=None,
+     ):
+         """If group_size is not None, we do GroupNorm with each group having group_size elements.
+         group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
+         """
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.eps = eps
+         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+         self.register_parameter("bias", None)
+         self.group_size = group_size
+         self.norm_before_gate = norm_before_gate
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.ones_(self.weight)
+
+     def forward(self, x, z=None):
+         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+         return rmsnorm_fn(
+             x,
+             self.weight,
+             self.bias,
+             z=z,
+             eps=self.eps,
+             group_size=self.group_size,
+             norm_before_gate=self.norm_before_gate,
+         )
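
Not part of the diff: a minimal PyTorch-only sketch of the gating semantics documented above, i.e. norm(x) * silu(z) when norm_before_gate=True versus norm(x * silu(z)) when it is False, exercised through the rms_norm_ref reference path. The import path assumes the new sglang.srt.layers.attention.fla package location added in this release.

import torch
import torch.nn.functional as F

from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_ref

torch.manual_seed(0)
x = torch.randn(2, 8, 64)  # hidden states
z = torch.randn(2, 8, 64)  # gate branch
w = torch.ones(64)         # RMSNorm scale, bias left as None

# norm_before_gate=True: y = rmsnorm(x) * silu(z)
y_norm_first = rms_norm_ref(x, w, None, z=z, norm_before_gate=True)

# norm_before_gate=False: y = rmsnorm(x * silu(z))
y_gate_first = rms_norm_ref(x, w, None, z=z, norm_before_gate=False)

# Hand-rolled check of the first variant against plain RMSNorm + silu gating.
rstd = torch.rsqrt(x.square().mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(y_norm_first, x * rstd * w * F.silu(z), atol=1e-5)

The fused Triton path (layernorm_fn / rmsnorm_fn) computes the same quantity on CUDA in a single kernel, with the HAS_Z and NORM_BEFORE_GATE constexpr flags selecting where the silu gate is applied.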
sglang/srt/layers/attention/fla/op.py
@@ -0,0 +1,66 @@
+ # Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/op.py
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import os
+
+ import triton
+ import triton.language as tl
+ import triton.language.extra.libdevice as tldevice
+
+ from sglang.srt.layers.attention.fla.utils import is_gather_supported
+
+ if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
+     exp = tldevice.fast_expf
+     exp2 = tldevice.exp2
+     log = tldevice.fast_logf
+     log2 = tldevice.fast_log2f
+ else:
+     exp = tl.exp
+     exp2 = tl.math.exp2
+     log = tl.log
+     log2 = tl.log2
+
+
+ @triton.jit
+ def safe_exp(x):
+     return exp(tl.where(x <= 0, x, float("-inf")))
+
+
+ if not is_gather_supported:
+
+     @triton.jit
+     def gather(src, index, axis, _builder=None):
+         """
+         Gather operation that works when tl.gather is not supported.
+         This is a fallback implementation that returns None.
+         Just to make triton compiler happy.
+         """
+         return None
+
+ else:
+     gather = tl.gather
+
+
+ if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
+     # For Triton 3.3.x
+     make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
+ elif hasattr(triton.language, "make_tensor_descriptor"):
+     # For Triton 3.4.x and later
+     make_tensor_descriptor = triton.language.make_tensor_descriptor
+ else:
+     """
+     Fallback implementation when TMA is not supported.
+     Returns None to indicate TMA descriptors are unavailable.
+     Just make triton compiler happy.
+     """
+
+     @triton.jit
+     def make_tensor_descriptor(
+         base,
+         shape,
+         strides,
+         block_shape,
+         _builder=None,
+     ):
+         return None
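
Not part of the diff: op.py centralizes math-op and capability fallbacks so the fla kernels can import exp/exp2/log/log2, gather, and make_tensor_descriptor from one place, with the libdevice fast variants selected when FLA_USE_FAST_OPS=1. Below is a minimal, illustrative sketch (requires CUDA) of how a downstream Triton kernel might consume safe_exp; the kernel name and driver are hypothetical, not taken from the diff.

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.op import safe_exp


@triton.jit
def _decay_kernel(X, Y, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask, other=0.0)
    # safe_exp returns exp(x) for x <= 0 and 0 for x > 0, so the
    # decay factor stays in [0, 1] and never overflows.
    tl.store(Y + offs, safe_exp(x), mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
_decay_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK=256)

The gather and make_tensor_descriptor shims follow the same pattern: they resolve to the real Triton API when the installed version supports it and otherwise fall back to stub definitions so the kernels still compile.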