sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,211 @@
1
+ # Copyright 2025 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ # Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
15
+
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+
20
+ import torch
21
+
22
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
23
+
24
+
25
+ @dataclass(kw_only=True)
26
+ class ForwardMetadata:
27
+ query_start_loc: torch.Tensor
28
+ mamba_cache_indices: torch.Tensor
29
+
30
+
31
+ @dataclass(kw_only=True)
32
+ class Mamba2Metadata(ForwardMetadata):
33
+ """stable metadata across all mamba2 layers in the forward pass"""
34
+
35
+ num_prefills: int
36
+ num_prefill_tokens: int
37
+ num_decodes: int
38
+
39
+ @dataclass(kw_only=True, frozen=True)
40
+ class MixedMetadata:
41
+ has_initial_states: torch.Tensor
42
+ prep_initial_states: bool
43
+
44
+ chunk_size: int
45
+ seq_idx: torch.Tensor
46
+ chunk_indices: torch.Tensor
47
+ chunk_offsets: torch.Tensor
48
+
49
+ extend_seq_lens_cpu: list[int]
50
+
51
+ mixed_metadata: MixedMetadata | None = None
52
+ """`mixed_metadata` is used for extend/mixed requests"""
53
+
54
+ @staticmethod
55
+ def _query_start_loc_to_chunk_indices_offsets(
56
+ query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
57
+ ) -> tuple[torch.Tensor, torch.Tensor]:
58
+ """
59
+ Args:
60
+ query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
61
+ lengths, shape (num_seqs + 1,).
62
+ The first element should be 0. Each entry represents the starting
63
+ index of a sequence in the flattened token array.
64
+ chunk_size (int): The size of each physical mamba chunk
65
+ (number of tokens per chunk).
66
+ total_seqlens (int): The total number of tokens in the batch.
67
+
68
+ Returns:
69
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
70
+ - chunk_indices (torch.Tensor): 1D tensor of indices
71
+ indicating the physical chunk for each logical chunk.
72
+ - chunk_offsets (torch.Tensor): 1D tensor of offsets
73
+ indicating the starting index of each logical chunk within
74
+ its physical chunk.
75
+
76
+ This function computes the chunk indices and offsets for the given
77
+ query_start_loc and chunk_size. Both are tensors of integers with length N,
78
+ where N is the number of logical (pseudo) chunks.
79
+ A logical chunk is a sequence of tokens that are all part of the same
80
+ sequence and are all in the same physical mamba chunk.
81
+ In other words, a logical chunk changes every time we cross a sequence
82
+ boundary or a physical mamba chunk boundary.
83
+ Logical chunks are needed to handle batched requests with initial states
84
+ (see _state_passing_fwd and _chunk_scan_fwd).
85
+ The chunk_indices tensor contains the index of the physical chunk for each
86
+ logical chunk.
87
+ The chunk_offsets tensor contains the offset (AKA starting index) of the
88
+ logical chunk in the physical chunk.
89
+
90
+ Example:
91
+ query_start_loc = [0, 5, 10]
92
+ chunk_size = 8
93
+ total_seqlens = 10
94
+ -> chunk_indices = [0, 0, 1]
95
+ -> chunk_offsets = [0, 5, 0]
96
+
97
+ In this example, we have 2 sequences, each with 5 tokens. The physical
98
+ chunk size is 8 tokens.
99
+ We have three logical chunks:
100
+ - the first logical chunk starts at token 0 in the first physical chunk
101
+ and contains all 5 tokens from the first sequence
102
+ - the second logical chunk starts at token 5 in the first physical chunk
103
+ and contains first 3 tokens from the second sequence
104
+ - the third logical chunk starts at token 0 in the second physical chunk
105
+ and contains the remaining 2 tokens from the second sequence
106
+ """
107
+
108
+ cu_seqlens = query_start_loc[1:] # remove prepended 0
109
+
110
+ # outputs will have length expansion of chunks that do not divide
111
+ # chunk_size
112
+ N = (
113
+ math.ceil(total_seqlens / chunk_size)
114
+ + (cu_seqlens[:-1] % chunk_size > 0).sum()
115
+ )
116
+ chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)
117
+ chunk_offsets = torch.zeros(
118
+ (N,), dtype=torch.int, device=query_start_loc.device
119
+ )
120
+
121
+ p = 0 # num of insertions
122
+ for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
123
+
124
+ # if does not divide chunk_size, then there is one chunk insertion
125
+ p += s % chunk_size > 0
126
+
127
+ # get the dimensions
128
+ # - the + 1 for _e is to shift the boundary by one chunk
129
+ # - this shifting is not needed if chunk_size divides e
130
+ _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)
131
+
132
+ # adjust indices and offsets
133
+ chunk_indices[_s:_e] -= p
134
+ chunk_offsets[_s] = s % chunk_size
135
+
136
+ return chunk_indices, chunk_offsets
137
+
138
+ @staticmethod
139
+ def prepare_decode(
140
+ query_start_loc: torch.Tensor,
141
+ mamba_cache_indices: torch.Tensor,
142
+ seq_lens: torch.Tensor,
143
+ ) -> "Mamba2Metadata":
144
+ """This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
145
+ return Mamba2Metadata(
146
+ query_start_loc=query_start_loc,
147
+ mamba_cache_indices=mamba_cache_indices,
148
+ num_decodes=len(seq_lens),
149
+ num_prefills=0,
150
+ num_prefill_tokens=0,
151
+ )
152
+
153
+ @classmethod
154
+ def prepare_mixed(
155
+ cls,
156
+ query_start_loc: torch.Tensor,
157
+ mamba_cache_indices: torch.Tensor,
158
+ chunk_size: int,
159
+ forward_batch: ForwardBatch,
160
+ ) -> "Mamba2Metadata":
161
+ """This path cannot run with CUDA graph, as it contains extend requests."""
162
+ if forward_batch.extend_num_tokens is None:
163
+ return cls.prepare_decode(
164
+ query_start_loc, mamba_cache_indices, forward_batch.seq_lens
165
+ )
166
+ num_prefills = len(forward_batch.extend_seq_lens)
167
+ num_prefill_tokens = forward_batch.extend_num_tokens
168
+ num_decodes = len(forward_batch.seq_lens) - num_prefills
169
+ context_lens_tensor = forward_batch.extend_prefix_lens
170
+ assert context_lens_tensor is not None
171
+ # precompute flag to avoid device syncs later
172
+ has_initial_states = context_lens_tensor > 0
173
+ prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()
174
+
175
+ query_start_loc = query_start_loc[: num_prefills + 1]
176
+ seq_idx = torch.repeat_interleave(
177
+ torch.arange(
178
+ num_prefills, dtype=torch.int32, device=query_start_loc.device
179
+ ),
180
+ query_start_loc.diff(),
181
+ output_size=num_prefill_tokens,
182
+ )
183
+ seq_idx.unsqueeze_(0)
184
+
185
+ # We compute metadata for chunked prefill once at the top level model
186
+ # forward and reuse them in mamba layers. If not needed, they will be
187
+ # ignored inside mamba kernels.
188
+ chunk_offsets, chunk_indices = None, None
189
+ if prep_initial_states:
190
+ chunk_indices, chunk_offsets = (
191
+ cls._query_start_loc_to_chunk_indices_offsets(
192
+ query_start_loc, chunk_size, num_prefill_tokens
193
+ )
194
+ )
195
+
196
+ return Mamba2Metadata(
197
+ query_start_loc=query_start_loc,
198
+ mamba_cache_indices=mamba_cache_indices,
199
+ num_prefills=num_prefills,
200
+ num_prefill_tokens=num_prefill_tokens,
201
+ num_decodes=num_decodes,
202
+ mixed_metadata=cls.MixedMetadata(
203
+ has_initial_states=has_initial_states,
204
+ prep_initial_states=prep_initial_states,
205
+ chunk_size=chunk_size,
206
+ seq_idx=seq_idx,
207
+ chunk_indices=chunk_indices,
208
+ chunk_offsets=chunk_offsets,
209
+ extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
210
+ ),
211
+ )
@@ -0,0 +1,120 @@
1
+ from typing import Union
2
+
3
+ import torch
4
+
5
+ from sglang.srt.custom_op import CustomOp
6
+ from sglang.srt.distributed.communication_op import (
7
+ tensor_model_parallel_all_gather,
8
+ tensor_model_parallel_all_reduce,
9
+ )
10
+ from sglang.srt.distributed.parallel_state import (
11
+ get_tensor_model_parallel_rank,
12
+ get_tensor_model_parallel_world_size,
13
+ )
14
+ from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
15
+ from sglang.srt.model_loader.weight_utils import sharded_weight_loader
16
+ from sglang.srt.utils.common import set_weight_attrs
17
+
18
+
19
class Mixer2RMSNormGated(CustomOp):
    """SiLU-gated RMSNorm with a weight sharded over tensor-parallel ranks.

    The input is first multiplied by ``silu(gate)`` (gate upcast to float32)
    and then, when ``use_rms_norm`` is set, RMS-normalized per group of
    ``full_hidden_size // full_n_groups`` channels and scaled by ``weight``.
    """

    def __init__(
        self,
        full_hidden_size: int,
        full_n_groups: int,
        use_rms_norm: bool = True,
        eps: float = 1e-6,
    ):
        """Record the sharding geometry and register the per-rank weight.

        Args:
            full_hidden_size: Hidden size before tensor-parallel sharding.
            full_n_groups: Number of normalization groups over the full
                hidden size.
            use_rms_norm: When False only the gating is applied and no
                ``weight`` parameter is created.
            eps: Value added to the variance before the reciprocal sqrt.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        channels_per_group = full_hidden_size // full_n_groups
        self.full_hidden_size = full_hidden_size
        self.group_size = channels_per_group
        self.per_rank_hidden_size = full_hidden_size // self.tp_size
        self.n_groups = full_hidden_size // channels_per_group

        self.variance_epsilon = eps
        self.use_rms_norm = use_rms_norm
        if not self.use_rms_norm:
            # No norm -> no weight; registering None keeps the attribute
            # present without creating a parameter.
            self.register_parameter("weight", None)
        else:
            # One scale per local channel, loaded through the sharded loader.
            self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
        assert (
            self.full_hidden_size % self.tp_size == 0
        ), "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Reference implementation built from plain torch ops.

        Tensor-parallel handling depends on the group layout:
          * a single group: the reduction dim itself is sharded, so each
            rank sums its local squares and the sums are all-reduced;
          * ``tp_size`` divides ``n_groups``: every group is local to one
            rank, no collective is issued;
          * otherwise: the input is all-gathered along the last dim, the
            norm is computed redundantly on every rank, and the rank's own
            slice is cut back out.

        Returns the result cast back to the dtype of ``x`` (multiplied by
        ``weight`` when RMSNorm is enabled).
        """
        out_dtype = x.dtype
        gated = x * torch.nn.functional.silu(gate.to(torch.float32))
        if not self.use_rms_norm:
            return gated.to(out_dtype)

        eps = self.variance_epsilon

        if self.n_groups == 1:
            if self.tp_size > 1:
                # Local sum of squares -> global sum -> mean over the full
                # (unsharded) reduction length.
                partial_sq = gated.pow(2).sum(dim=-1, keepdim=True)
                total_sq = tensor_model_parallel_all_reduce(partial_sq)
                full_width = self.tp_size * gated.shape[-1]
                variance = total_sq / full_width
            else:
                variance = gated.pow(2).mean(-1, keepdim=True)
            normed = gated * torch.rsqrt(variance + eps)
            return self.weight * normed.to(out_dtype)

        # Groups straddle rank boundaries when tp_size does not divide them.
        needs_gather: bool = self.n_groups % self.tp_size != 0
        if needs_gather:
            gated = tensor_model_parallel_all_gather(gated, -1)

        *lead_dims, width = gated.shape
        grouped = gated.view(*lead_dims, width // self.group_size, self.group_size)
        variance = grouped.pow(2).mean(-1, keepdim=True)
        grouped = grouped * torch.rsqrt(variance + eps)
        normed = grouped.view(*lead_dims, width)

        if needs_gather:
            # Keep only the channels owned by this rank.
            lo = self.per_rank_hidden_size * self.tp_rank
            hi = lo + self.per_rank_hidden_size
            normed = normed[..., lo:hi]

        return self.weight * normed.to(out_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: use the fused kernel when the layout allows it.

        The fused ``rms_norm_gated`` call is taken only for a single group
        whose count is divisible by ``tp_size``; every other configuration
        is delegated to ``forward_native``.
        """
        out_dtype = x.dtype
        if not self.use_rms_norm:
            # Keep gate in float32 for numerical stability during silu
            silu_gate = torch.nn.functional.silu(gate.to(torch.float32))
            return x * silu_gate.to(out_dtype)

        fused_ok = self.n_groups == 1 and (self.n_groups % self.tp_size) == 0
        if not fused_ok:
            return self.forward_native(x, gate)

        return rms_norm_gated(
            x=x,
            weight=self.weight.data,
            bias=None,
            z=gate,
            eps=self.variance_epsilon,
            norm_before_gate=False,
            is_rms_norm=True,
        )
@@ -0,0 +1,2 @@
1
+ from .mamba_ssm import selective_state_update
2
+ from .ssd_combined import mamba_chunk_scan_combined
@@ -0,0 +1,172 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ # Copyright (c) 2024, Tri Dao.
4
+ # Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+
11
# Single-pass (one program per row and group) forward kernel for gated
# LayerNorm / RMSNorm.
#
# Launch grid: axis 0 indexes the row of X, axis 1 indexes the group.
# Each program handles N contiguous columns (one group) of one row, padded up
# to BLOCK_N lanes; lanes with cols >= N are masked.
#
# Compile-time switches:
#   HAS_BIAS          - derived from B being non-None.
#   HAS_Z             - derived from Z being non-None.
#   NORM_BEFORE_GATE  - if True the gate z*sigmoid(z) multiplies the normalized
#                       output; otherwise it multiplies the input before the
#                       statistics are computed.
#   IS_RMS_NORM       - skip mean subtraction (and the Mean store).
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row: tl.int64,
    stride_y_row: tl.int64,
    stride_z_row: tl.int64,
    M: tl.int64,  # number of rows in X
    N: tl.int64,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    # Statistics are stored group-major: element (group, row) lives at
    # index group * M + row.
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        # `other=0.0` matters: without it the padding lanes of z are
        # unspecified, and 0 * garbage (possibly NaN/Inf) would leak into the
        # unmasked tl.sum(x) used for the LayerNorm mean below.
        z = tl.load(Z + cols, mask=cols < N, other=0.0).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    # `mean` only exists on the LayerNorm path; IS_RMS_NORM is a constexpr so
    # the unused branch is never compiled.
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output (padding lanes are masked out of the store)
    tl.store(Y + cols, y, mask=mask)
75
+
76
+
77
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Validate inputs, allocate buffers and launch the fused norm kernel.

    Args:
        x: 2-D input; its last dim must have stride 1.
        weight: 1-D scale with one entry per column of ``x``.
        bias: Optional 1-D bias with one entry per column of ``x``.
        eps: Value added to the variance before the reciprocal sqrt.
        z: Optional gate tensor with the same shape as ``x``.
        out: Optional pre-allocated output with the same shape as ``x``.
        group_size: Columns per normalization group; defaults to all columns.
        norm_before_gate: Forwarded to the kernel as ``NORM_BEFORE_GATE``.
        is_rms_norm: Forwarded to the kernel as ``IS_RMS_NORM``.

    Returns:
        ``(out, mean, rstd)``. ``mean`` is ``None`` when ``is_rms_norm`` is
        set; otherwise ``mean`` and ``rstd`` are flat float32 tensors of
        length ``ngroups * rows``.

    Raises:
        RuntimeError: If one group does not fit in a single fused block.
    """
    num_rows, num_cols = x.shape
    group_size = num_cols if group_size is None else group_size
    assert num_cols % group_size == 0
    ngroups = num_cols // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (num_rows, num_cols)
    assert weight.shape == (num_cols,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (num_cols,)
    # allocate output
    if out is None:
        out = torch.empty_like(x)
    else:
        assert out.shape == x.shape
    assert out.stride(-1) == 1

    # One statistic per (group, row) pair.
    stats_shape = (ngroups * num_rows,)
    mean = (
        None
        if is_rms_norm
        else torch.empty(stats_shape, dtype=torch.float32, device=x.device)
    )
    rstd = torch.empty(stats_shape, dtype=torch.float32, device=x.device)

    # Less than 64KB per feature: enqueue fused kernel
    max_fused_size = 65536 // x.element_size()
    block_n = min(max_fused_size, triton.next_power_of_2(group_size))
    if group_size > block_n:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(block_n // 256, 1), 8)

    launch_grid = (num_rows, ngroups)
    z_row_stride = 0 if z is None else z.stride(0)
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[launch_grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z_row_stride,
            num_rows,
            group_size,
            eps,
            BLOCK_N=block_n,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            num_warps=num_warps,
        )
    return out, mean, rstd
143
+
144
+
145
def rms_norm_gated(
    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
    """Apply gated RMSNorm over the last dimension of ``x``.

    ``x`` (and ``z`` when given) are flattened to 2-D, made contiguous along
    the last dim if necessary, and passed to ``_layer_norm_fwd`` with
    ``is_rms_norm=True``.

    Args:
        x: Input tensor of any rank.
        weight: Scale vector; made contiguous before use.
        bias: Optional bias vector; made contiguous before use.
        z: Optional gate tensor; must have exactly the shape of ``x``.
        eps: Value added to the variance before the reciprocal sqrt.
        group_size: Columns per normalization group (``None`` = all columns).
        norm_before_gate: Whether the gate is applied after normalization.

    Returns:
        The normalized tensor reshaped to the original shape of ``x``.
    """
    original_shape = x.shape

    # reshape input data into 2D tensor
    x_2d = x.reshape(-1, original_shape[-1])
    if x_2d.stride(-1) != 1:
        x_2d = x_2d.contiguous()

    z_2d = z
    if z_2d is not None:
        assert z_2d.shape == original_shape
        z_2d = z_2d.reshape(-1, z_2d.shape[-1])
        if z_2d.stride(-1) != 1:
            z_2d = z_2d.contiguous()

    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    # Only the output is needed here; the statistics are discarded.
    fwd_result = _layer_norm_fwd(
        x_2d,
        weight,
        bias,
        eps,
        z=z_2d,
        group_size=group_size,
        norm_before_gate=norm_before_gate,
        is_rms_norm=True,
    )
    return fwd_result[0].reshape(original_shape)