sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265) hide show
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
35
35
  is_cuda,
36
36
  is_hip,
37
37
  is_npu,
38
+ is_xpu,
38
39
  set_weight_attrs,
39
40
  )
40
41
  from sglang.utils import resolve_obj_by_qualname
@@ -44,8 +45,9 @@ _is_npu = is_npu()
44
45
  _is_cpu_amx_available = cpu_has_amx_support()
45
46
  _is_cpu = is_cpu()
46
47
  _is_hip = is_hip()
48
+ _is_xpu = is_xpu()
47
49
 
48
- if _is_cuda:
50
+ if _is_cuda or _is_xpu:
49
51
  from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
50
52
  elif _is_hip:
51
53
  from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ class SiluAndMul(CustomOp):
70
72
 
71
73
  def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
72
74
  if _is_cpu_amx_available:
73
- d = x.shape[-1] // 2
74
- output_shape = x.shape[:-1] + (d,)
75
75
  out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
76
76
  return out
77
77
  else:
@@ -81,17 +81,20 @@ class SiluAndMul(CustomOp):
81
81
  out = torch_npu.npu_swiglu(x)
82
82
  return out
83
83
 
84
+ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
85
+ d = x.shape[-1] // 2
86
+ output_shape = x.shape[:-1] + (d,)
87
+ out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
88
+ silu_and_mul(x, out)
89
+ return out
90
+
84
91
 
85
92
  class GeluAndMul(CustomOp):
86
93
  def __init__(self, approximate="tanh"):
87
94
  super().__init__()
88
95
  self.approximate = approximate
89
96
 
90
- def forward_native(self, x: torch.Tensor) -> torch.Tensor:
91
- d = x.shape[-1] // 2
92
- return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
93
-
94
- def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
97
+ def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
95
98
  d = x.shape[-1] // 2
96
99
  output_shape = x.shape[:-1] + (d,)
97
100
  out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,24 @@ class GeluAndMul(CustomOp):
103
106
  raise RuntimeError("GeluAndMul only support tanh or none")
104
107
  return out
105
108
 
109
+ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
110
+ d = x.shape[-1] // 2
111
+ return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
112
+
113
+ def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
114
+ if _is_cpu_amx_available and self.approximate == "tanh":
115
+ return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
116
+ elif _is_cpu_amx_available and self.approximate == "none":
117
+ return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
118
+ else:
119
+ return self.forward_native(x)
120
+
121
+ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
122
+ return self._forward_impl(x)
123
+
124
+ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
125
+ return self._forward_impl(x)
126
+
106
127
  def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
107
128
  y_npu, gelu_npu = torch_npu.npu_geglu(
108
129
  x,
@@ -150,6 +171,115 @@ class QuickGELU(CustomOp):
150
171
  return torch_npu.npu_fast_gelu(x)
151
172
 
152
173
 
174
+ class XIELU(CustomOp):
175
+ """
176
+ Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
177
+ If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
178
+ Otherwise, we emit a single warning and use xIELU Python
179
+ """
180
+
181
+ def __init__(
182
+ self,
183
+ alpha_p_init: float = 0.8,
184
+ alpha_n_init: float = 0.8,
185
+ beta: float = 0.5,
186
+ eps: float = -1e-6,
187
+ dtype: torch.dtype = torch.bfloat16,
188
+ with_vector_loads: bool = False,
189
+ ):
190
+ super().__init__()
191
+ self.alpha_p = nn.Parameter(
192
+ torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
193
+ 0
194
+ )
195
+ )
196
+ self.alpha_n = nn.Parameter(
197
+ torch.log(
198
+ torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
199
+ ).unsqueeze(0)
200
+ )
201
+ self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
202
+ self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
203
+ self.with_vector_loads = with_vector_loads
204
+ # Temporary until xIELU CUDA fully implemented
205
+ self._beta_scalar = float(self.beta.detach().cpu().float().item())
206
+ self._eps_scalar = float(self.eps.detach().cpu().float().item())
207
+
208
+ self._xielu_cuda_obj = None
209
+ try:
210
+ import xielu.ops # noqa: F401
211
+
212
+ self._xielu_cuda_obj = torch.classes.xielu.XIELU()
213
+ msg = "Using experimental xIELU CUDA."
214
+ try:
215
+ from torch._dynamo import allow_in_graph
216
+
217
+ self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
218
+ msg += " Enabled torch._dynamo for xIELU CUDA."
219
+ except Exception as err:
220
+ msg += (
221
+ f" Could not enable torch._dynamo for xIELU ({err}) - "
222
+ "this may result in slower performance."
223
+ )
224
+ self._xielu_cuda_fn = self._xielu_cuda
225
+ logger.warning_once(msg)
226
+ except Exception as err:
227
+ logger.warning_once(
228
+ "CUDA-fused xIELU not available (%s) –"
229
+ " falling back to a Python version.\n"
230
+ "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
231
+ str(err),
232
+ )
233
+
234
+ def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
235
+ alpha_p = nn.functional.softplus(self.alpha_p)
236
+ alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
237
+ return torch.where(
238
+ x > 0,
239
+ alpha_p * x * x + self.beta * x,
240
+ (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
241
+ )
242
+
243
+ def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
244
+ """Firewall function to prevent torch.compile from seeing .item()"""
245
+ assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
246
+ original_shape = x.shape
247
+ # CUDA kernel expects 3D tensors, reshape if needed
248
+ while x.dim() < 3:
249
+ x = x.unsqueeze(0)
250
+ if x.dim() > 3:
251
+ x = x.view(-1, 1, x.size(-1))
252
+ if original_shape != x.shape:
253
+ logger.warning_once(
254
+ "Warning: xIELU input tensor expects 3 dimensions"
255
+ " but got (shape: %s). Reshaping to (shape: %s).\n"
256
+ "Note: For SGLang this may be expected if sending"
257
+ "[B*S,D] instead of [B,S,D].",
258
+ original_shape,
259
+ x.shape,
260
+ )
261
+ result = self._xielu_cuda_obj.forward(
262
+ x,
263
+ self.alpha_p,
264
+ self.alpha_n,
265
+ # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
266
+ self._beta_scalar,
267
+ self._eps_scalar,
268
+ self.with_vector_loads,
269
+ )
270
+ return result.view(original_shape)
271
+
272
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
273
+ if self._xielu_cuda_obj is not None and input.is_cuda:
274
+ if not torch._dynamo.is_compiling():
275
+ return self._xielu_cuda_fn(input)
276
+ else:
277
+ logger.warning_once(
278
+ "torch._dynamo is compiling, using Python version of xIELU."
279
+ )
280
+ return self._xielu_python(input)
281
+
282
+
153
283
  class ScaledActivation(nn.Module):
154
284
  """An activation function with post-scale parameters.
155
285
 
@@ -197,6 +327,7 @@ _ACTIVATION_REGISTRY = {
197
327
  "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
198
328
  "gelu_new": NewGELU(),
199
329
  "relu2": ReLU2(),
330
+ "xielu": XIELU(),
200
331
  }
201
332
 
202
333
 
@@ -242,7 +373,9 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
242
373
  return nn.Identity()
243
374
 
244
375
 
245
- if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
376
+ if not (
377
+ _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
378
+ ):
246
379
  logger.info(
247
380
  "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
248
381
  )
@@ -18,7 +18,10 @@ import triton.language as tl
18
18
  from sglang.global_config import global_config
19
19
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
20
20
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
21
- from sglang.srt.layers.dp_attention import get_attention_tp_size
21
+ from sglang.srt.layers.dp_attention import (
22
+ get_attention_tp_size,
23
+ is_dp_attention_enabled,
24
+ )
22
25
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
23
26
 
24
27
  if TYPE_CHECKING:
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
154
157
  (max_bs + 1,), dtype=torch.int32, device=model_runner.device
155
158
  )
156
159
 
160
+ self.enable_dp_attention = is_dp_attention_enabled()
161
+
157
162
  def init_forward_metadata(self, forward_batch: ForwardBatch):
158
163
  """Init auxiliary variables for triton attention backend."""
159
164
 
@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
302
307
  if self.use_mla:
303
308
  self.mla_indices_updater_prefill.update(
304
309
  forward_batch.req_pool_indices,
305
- forward_batch.extend_prefix_lens,
306
- sum(forward_batch.extend_prefix_lens_cpu),
310
+ forward_batch.seq_lens,
311
+ forward_batch.seq_lens_sum,
307
312
  forward_batch.extend_seq_lens,
308
- max(forward_batch.extend_seq_lens_cpu),
309
- forward_batch.seq_lens_cpu.max().item(),
313
+ forward_batch.extend_seq_lens.max().item(),
314
+ forward_batch.seq_lens.max().item(),
310
315
  spec_info=None,
311
316
  )
312
- self.mla_indices_updater_prefill.kv_indptr += (
313
- self.mla_indices_updater_prefill.qo_indptr
314
- )
317
+
318
+ kv_indices = self.mla_indices_updater_prefill.kv_indices
319
+
315
320
  self.forward_metadata = ForwardMetadata(
316
321
  self.mla_indices_updater_prefill.kv_indptr,
317
- self.mla_indices_updater_prefill.kv_indices,
322
+ kv_indices,
318
323
  self.mla_indices_updater_prefill.qo_indptr,
319
324
  self.kv_last_page_len[:bs],
320
325
  self.mla_indices_updater_prefill.max_q_len,
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
614
619
  assert len(k.shape) == 3
615
620
  assert len(v.shape) == 3
616
621
 
617
- if kv_indices.shape[0] == 0:
618
- o = flash_attn_varlen_func(
619
- q,
620
- k,
621
- v,
622
- qo_indptr,
623
- qo_indptr,
624
- max_q_len,
625
- max_q_len,
626
- softmax_scale=layer.scaling,
627
- causal=True,
628
- )
629
- return o
630
- elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
631
- K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
632
- kvc, k_pe = torch.split(
633
- K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
634
- )
635
- kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
622
+ if forward_batch.forward_mode.is_extend():
623
+ if kv_indices.shape[0] == 0:
624
+ o = flash_attn_varlen_func(
625
+ q,
626
+ k,
627
+ v,
628
+ qo_indptr,
629
+ qo_indptr,
630
+ max_q_len,
631
+ max_q_len,
632
+ softmax_scale=layer.scaling,
633
+ causal=True,
634
+ )
635
+ return o
636
+ elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
637
+ K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
638
+ kvc, k_pe = torch.split(
639
+ K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
640
+ )
641
+ kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
636
642
 
637
- kvprefix = kvprefix.view(
638
- -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
639
- )
640
- k_prefix, v_prefix = torch.split(
641
- kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
642
- )
643
- k_prefix = torch.cat(
644
- [
645
- k_prefix,
646
- torch.broadcast_to(
647
- k_pe,
648
- (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
649
- ),
650
- ],
651
- dim=-1,
652
- )
653
- assert (
654
- forward_batch.extend_prefix_lens.shape
655
- == forward_batch.extend_seq_lens.shape
656
- )
657
- k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
658
- k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
659
- assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
660
- k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
661
- v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
662
- v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
663
- v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
664
-
665
- o = flash_attn_varlen_func(
666
- q,
667
- k,
668
- v,
669
- qo_indptr,
670
- kv_indptr,
671
- max_q_len,
672
- max_kv_len,
673
- softmax_scale=layer.scaling,
674
- causal=True,
675
- )
676
- return o
643
+ kvprefix = kvprefix.view(
644
+ -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
645
+ )
646
+ k_prefix, v_prefix = torch.split(
647
+ kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
648
+ )
649
+ k_prefix = torch.cat(
650
+ [
651
+ k_prefix,
652
+ torch.broadcast_to(
653
+ k_pe,
654
+ (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
655
+ ),
656
+ ],
657
+ dim=-1,
658
+ )
659
+ assert (
660
+ forward_batch.extend_prefix_lens.shape
661
+ == forward_batch.extend_seq_lens.shape
662
+ )
663
+
664
+ k = k_prefix
665
+ v = v_prefix
666
+
667
+ o = flash_attn_varlen_func(
668
+ q,
669
+ k,
670
+ v,
671
+ qo_indptr,
672
+ kv_indptr,
673
+ max_q_len,
674
+ max_kv_len,
675
+ softmax_scale=layer.scaling,
676
+ causal=True,
677
+ )
678
+ return o
679
+
680
+ else:
681
+ if layer.qk_head_dim != layer.v_head_dim:
682
+ o = q.new_empty(
683
+ (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
684
+ )
685
+ else:
686
+ o = torch.empty_like(q)
687
+
688
+ mla_prefill_fwd(
689
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
690
+ K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
691
+ o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
692
+ qo_indptr,
693
+ kv_indptr,
694
+ kv_indices,
695
+ self.forward_metadata.kv_last_page_len,
696
+ self.forward_metadata.max_q_len,
697
+ layer.scaling,
698
+ layer.logit_cap,
699
+ )
700
+ K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
701
+ return o
677
702
  elif forward_batch.forward_mode.is_target_verify():
678
703
  o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
679
704
  mla_decode_fwd(
@@ -10,6 +10,7 @@ from torch.nn.functional import scaled_dot_product_attention
10
10
  from sglang.srt.configs.model_config import AttentionArch
11
11
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
12
12
  from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
13
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
13
14
  from sglang.srt.layers.radix_attention import AttentionType
14
15
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
15
16
  from sglang.srt.utils import get_bool_env_var
@@ -33,6 +34,7 @@ class ForwardMetadata:
33
34
  extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
34
35
  seq_lens_cpu_int: Optional[torch.Tensor] = None
35
36
  seq_lens_cpu_list: Optional[List[int]] = None
37
+ seq_lens_list_cumsum: Optional[List[int]] = None
36
38
 
37
39
 
38
40
  class AscendAttnBackend(AttentionBackend):
@@ -83,6 +85,7 @@ class AscendAttnBackend(AttentionBackend):
83
85
 
84
86
  def init_forward_metadata(self, forward_batch: ForwardBatch):
85
87
  """Init the metadata for a forward pass."""
88
+ tp_size = get_attention_tp_size()
86
89
  self.forward_metadata = ForwardMetadata()
87
90
 
88
91
  self.forward_metadata.block_tables = (
@@ -96,9 +99,13 @@ class AscendAttnBackend(AttentionBackend):
96
99
  forward_batch.extend_seq_lens.cpu().int()
97
100
  )
98
101
  self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
99
- self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
100
- forward_batch.extend_seq_lens_cpu
101
- )
102
+
103
+ seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
104
+ if forward_batch.is_extend_in_batch:
105
+ seq_lens_list_cumsum[-1] = (
106
+ (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
107
+ ) * tp_size
108
+ self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
102
109
 
103
110
  self.graph_mode = False
104
111
 
@@ -368,7 +375,7 @@ class AscendAttnBackend(AttentionBackend):
368
375
  -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
369
376
  )
370
377
 
371
- q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
378
+ q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
372
379
  q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
373
380
  if self.forward_metadata.seq_lens_cpu_int is None:
374
381
  actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list