sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,239 @@
1
+ # This file is auto-generated. Do not edit manually.
2
+ # Regenerate with: python compile_proto.py
3
+
4
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
5
+ """Client and server classes corresponding to protobuf-defined services."""
6
+ import grpc
7
+ import warnings
8
+
9
+ from . import sglang_scheduler_pb2 as sglang__scheduler__pb2
10
+
11
+ GRPC_GENERATED_VERSION = '1.74.0'
12
+ GRPC_VERSION = grpc.__version__
13
+ _version_not_supported = False
14
+
15
+ try:
16
+ from grpc._utilities import first_version_is_lower
17
+ _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
18
+ except ImportError:
19
+ _version_not_supported = True
20
+
21
+ if _version_not_supported:
22
+ raise RuntimeError(
23
+ f'The grpc package installed is at version {GRPC_VERSION},'
24
+ + f' but the generated code in sglang_scheduler_pb2_grpc.py depends on'
25
+ + f' grpcio>={GRPC_GENERATED_VERSION}.'
26
+ + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
27
+ + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
28
+ )
29
+
30
+
31
+ class SglangSchedulerStub(object):
32
+ """Service definition for SGLang scheduler communication
33
+ This protocol bridges the Rust router and Python scheduler
34
+ """
35
+
36
+ def __init__(self, channel):
37
+ """Constructor.
38
+
39
+ Args:
40
+ channel: A grpc.Channel.
41
+ """
42
+ self.Generate = channel.unary_stream(
43
+ '/sglang.grpc.scheduler.SglangScheduler/Generate',
44
+ request_serializer=sglang__scheduler__pb2.GenerateRequest.SerializeToString,
45
+ response_deserializer=sglang__scheduler__pb2.GenerateResponse.FromString,
46
+ _registered_method=True)
47
+ self.Embed = channel.unary_unary(
48
+ '/sglang.grpc.scheduler.SglangScheduler/Embed',
49
+ request_serializer=sglang__scheduler__pb2.EmbedRequest.SerializeToString,
50
+ response_deserializer=sglang__scheduler__pb2.EmbedResponse.FromString,
51
+ _registered_method=True)
52
+ self.HealthCheck = channel.unary_unary(
53
+ '/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
54
+ request_serializer=sglang__scheduler__pb2.HealthCheckRequest.SerializeToString,
55
+ response_deserializer=sglang__scheduler__pb2.HealthCheckResponse.FromString,
56
+ _registered_method=True)
57
+ self.Abort = channel.unary_unary(
58
+ '/sglang.grpc.scheduler.SglangScheduler/Abort',
59
+ request_serializer=sglang__scheduler__pb2.AbortRequest.SerializeToString,
60
+ response_deserializer=sglang__scheduler__pb2.AbortResponse.FromString,
61
+ _registered_method=True)
62
+
63
+
64
+ class SglangSchedulerServicer(object):
65
+ """Service definition for SGLang scheduler communication
66
+ This protocol bridges the Rust router and Python scheduler
67
+ """
68
+
69
+ def Generate(self, request, context):
70
+ """Submit a generation request (supports streaming)
71
+ """
72
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
73
+ context.set_details('Method not implemented!')
74
+ raise NotImplementedError('Method not implemented!')
75
+
76
+ def Embed(self, request, context):
77
+ """Submit an embedding request
78
+ """
79
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
80
+ context.set_details('Method not implemented!')
81
+ raise NotImplementedError('Method not implemented!')
82
+
83
+ def HealthCheck(self, request, context):
84
+ """Health check and metrics
85
+ """
86
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
87
+ context.set_details('Method not implemented!')
88
+ raise NotImplementedError('Method not implemented!')
89
+
90
+ def Abort(self, request, context):
91
+ """Abort a running request
92
+ """
93
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
94
+ context.set_details('Method not implemented!')
95
+ raise NotImplementedError('Method not implemented!')
96
+
97
+
98
+ def add_SglangSchedulerServicer_to_server(servicer, server):
99
+ rpc_method_handlers = {
100
+ 'Generate': grpc.unary_stream_rpc_method_handler(
101
+ servicer.Generate,
102
+ request_deserializer=sglang__scheduler__pb2.GenerateRequest.FromString,
103
+ response_serializer=sglang__scheduler__pb2.GenerateResponse.SerializeToString,
104
+ ),
105
+ 'Embed': grpc.unary_unary_rpc_method_handler(
106
+ servicer.Embed,
107
+ request_deserializer=sglang__scheduler__pb2.EmbedRequest.FromString,
108
+ response_serializer=sglang__scheduler__pb2.EmbedResponse.SerializeToString,
109
+ ),
110
+ 'HealthCheck': grpc.unary_unary_rpc_method_handler(
111
+ servicer.HealthCheck,
112
+ request_deserializer=sglang__scheduler__pb2.HealthCheckRequest.FromString,
113
+ response_serializer=sglang__scheduler__pb2.HealthCheckResponse.SerializeToString,
114
+ ),
115
+ 'Abort': grpc.unary_unary_rpc_method_handler(
116
+ servicer.Abort,
117
+ request_deserializer=sglang__scheduler__pb2.AbortRequest.FromString,
118
+ response_serializer=sglang__scheduler__pb2.AbortResponse.SerializeToString,
119
+ ),
120
+ }
121
+ generic_handler = grpc.method_handlers_generic_handler(
122
+ 'sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)
123
+ server.add_generic_rpc_handlers((generic_handler,))
124
+ server.add_registered_method_handlers('sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)
125
+
126
+
127
+ # This class is part of an EXPERIMENTAL API.
128
+ class SglangScheduler(object):
129
+ """Service definition for SGLang scheduler communication
130
+ This protocol bridges the Rust router and Python scheduler
131
+ """
132
+
133
+ @staticmethod
134
+ def Generate(request,
135
+ target,
136
+ options=(),
137
+ channel_credentials=None,
138
+ call_credentials=None,
139
+ insecure=False,
140
+ compression=None,
141
+ wait_for_ready=None,
142
+ timeout=None,
143
+ metadata=None):
144
+ return grpc.experimental.unary_stream(
145
+ request,
146
+ target,
147
+ '/sglang.grpc.scheduler.SglangScheduler/Generate',
148
+ sglang__scheduler__pb2.GenerateRequest.SerializeToString,
149
+ sglang__scheduler__pb2.GenerateResponse.FromString,
150
+ options,
151
+ channel_credentials,
152
+ insecure,
153
+ call_credentials,
154
+ compression,
155
+ wait_for_ready,
156
+ timeout,
157
+ metadata,
158
+ _registered_method=True)
159
+
160
+ @staticmethod
161
+ def Embed(request,
162
+ target,
163
+ options=(),
164
+ channel_credentials=None,
165
+ call_credentials=None,
166
+ insecure=False,
167
+ compression=None,
168
+ wait_for_ready=None,
169
+ timeout=None,
170
+ metadata=None):
171
+ return grpc.experimental.unary_unary(
172
+ request,
173
+ target,
174
+ '/sglang.grpc.scheduler.SglangScheduler/Embed',
175
+ sglang__scheduler__pb2.EmbedRequest.SerializeToString,
176
+ sglang__scheduler__pb2.EmbedResponse.FromString,
177
+ options,
178
+ channel_credentials,
179
+ insecure,
180
+ call_credentials,
181
+ compression,
182
+ wait_for_ready,
183
+ timeout,
184
+ metadata,
185
+ _registered_method=True)
186
+
187
+ @staticmethod
188
+ def HealthCheck(request,
189
+ target,
190
+ options=(),
191
+ channel_credentials=None,
192
+ call_credentials=None,
193
+ insecure=False,
194
+ compression=None,
195
+ wait_for_ready=None,
196
+ timeout=None,
197
+ metadata=None):
198
+ return grpc.experimental.unary_unary(
199
+ request,
200
+ target,
201
+ '/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
202
+ sglang__scheduler__pb2.HealthCheckRequest.SerializeToString,
203
+ sglang__scheduler__pb2.HealthCheckResponse.FromString,
204
+ options,
205
+ channel_credentials,
206
+ insecure,
207
+ call_credentials,
208
+ compression,
209
+ wait_for_ready,
210
+ timeout,
211
+ metadata,
212
+ _registered_method=True)
213
+
214
+ @staticmethod
215
+ def Abort(request,
216
+ target,
217
+ options=(),
218
+ channel_credentials=None,
219
+ call_credentials=None,
220
+ insecure=False,
221
+ compression=None,
222
+ wait_for_ready=None,
223
+ timeout=None,
224
+ metadata=None):
225
+ return grpc.experimental.unary_unary(
226
+ request,
227
+ target,
228
+ '/sglang.grpc.scheduler.SglangScheduler/Abort',
229
+ sglang__scheduler__pb2.AbortRequest.SerializeToString,
230
+ sglang__scheduler__pb2.AbortResponse.FromString,
231
+ options,
232
+ channel_credentials,
233
+ insecure,
234
+ call_credentials,
235
+ compression,
236
+ wait_for_ready,
237
+ timeout,
238
+ metadata,
239
+ _registered_method=True)
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
35
35
  is_cuda,
36
36
  is_hip,
37
37
  is_npu,
38
+ is_xpu,
38
39
  set_weight_attrs,
39
40
  )
40
41
  from sglang.utils import resolve_obj_by_qualname
@@ -44,8 +45,9 @@ _is_npu = is_npu()
44
45
  _is_cpu_amx_available = cpu_has_amx_support()
45
46
  _is_cpu = is_cpu()
46
47
  _is_hip = is_hip()
48
+ _is_xpu = is_xpu()
47
49
 
48
- if _is_cuda:
50
+ if _is_cuda or _is_xpu:
49
51
  from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
50
52
  elif _is_hip:
51
53
  from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ class SiluAndMul(CustomOp):
70
72
 
71
73
  def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
72
74
  if _is_cpu_amx_available:
73
- d = x.shape[-1] // 2
74
- output_shape = x.shape[:-1] + (d,)
75
75
  out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
76
76
  return out
77
77
  else:
@@ -81,17 +81,20 @@ class SiluAndMul(CustomOp):
81
81
  out = torch_npu.npu_swiglu(x)
82
82
  return out
83
83
 
84
+ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
85
+ d = x.shape[-1] // 2
86
+ output_shape = x.shape[:-1] + (d,)
87
+ out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
88
+ silu_and_mul(x, out)
89
+ return out
90
+
84
91
 
85
92
  class GeluAndMul(CustomOp):
86
93
  def __init__(self, approximate="tanh"):
87
94
  super().__init__()
88
95
  self.approximate = approximate
89
96
 
90
- def forward_native(self, x: torch.Tensor) -> torch.Tensor:
91
- d = x.shape[-1] // 2
92
- return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
93
-
94
- def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
97
+ def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
95
98
  d = x.shape[-1] // 2
96
99
  output_shape = x.shape[:-1] + (d,)
97
100
  out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,24 @@ class GeluAndMul(CustomOp):
103
106
  raise RuntimeError("GeluAndMul only support tanh or none")
104
107
  return out
105
108
 
109
+ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
110
+ d = x.shape[-1] // 2
111
+ return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
112
+
113
+ def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
114
+ if _is_cpu_amx_available and self.approximate == "tanh":
115
+ return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
116
+ elif _is_cpu_amx_available and self.approximate == "none":
117
+ return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
118
+ else:
119
+ return self.forward_native(x)
120
+
121
+ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
122
+ return self._forward_impl(x)
123
+
124
+ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
125
+ return self._forward_impl(x)
126
+
106
127
  def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
107
128
  y_npu, gelu_npu = torch_npu.npu_geglu(
108
129
  x,
@@ -150,6 +171,116 @@ class QuickGELU(CustomOp):
150
171
  return torch_npu.npu_fast_gelu(x)
151
172
 
152
173
 
174
+ class XIELU(CustomOp):
175
+ """
176
+ Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
177
+ If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
178
+ Otherwise, we emit a single warning and use xIELU Python
179
+ """
180
+
181
+ def __init__(
182
+ self,
183
+ alpha_p_init: float = 0.8,
184
+ alpha_n_init: float = 0.8,
185
+ beta: float = 0.5,
186
+ eps: float = -1e-6,
187
+ dtype: torch.dtype = torch.bfloat16,
188
+ with_vector_loads: bool = False,
189
+ ):
190
+ super().__init__()
191
+ self.alpha_p = nn.Parameter(
192
+ torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
193
+ 0
194
+ )
195
+ )
196
+ self.alpha_n = nn.Parameter(
197
+ torch.log(
198
+ torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
199
+ ).unsqueeze(0)
200
+ )
201
+ self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
202
+ self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
203
+ self.with_vector_loads = with_vector_loads
204
+ # Temporary until xIELU CUDA fully implemented
205
+ self._beta_scalar = float(self.beta.detach().cpu().float().item())
206
+ self._eps_scalar = float(self.eps.detach().cpu().float().item())
207
+
208
+ self._xielu_cuda_obj = None
209
+ try:
210
+ import xielu.ops # noqa: F401
211
+
212
+ self._xielu_cuda_obj = torch.classes.xielu.XIELU()
213
+ msg = "Using experimental xIELU CUDA."
214
+ try:
215
+ from torch._dynamo import allow_in_graph
216
+
217
+ self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
218
+ msg += " Enabled torch._dynamo for xIELU CUDA."
219
+ except Exception as err:
220
+ msg += (
221
+ f" Could not enable torch._dynamo for xIELU ({err}) - "
222
+ "this may result in slower performance."
223
+ )
224
+ self._xielu_cuda_fn = self._xielu_cuda
225
+ logger.warning_once(msg)
226
+ except Exception as err:
227
+ pass
228
+ # logger.warning_once(
229
+ # "CUDA-fused xIELU not available (%s) –"
230
+ # " falling back to a Python version.\n"
231
+ # "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
232
+ # str(err),
233
+ # )
234
+
235
+ def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
236
+ alpha_p = nn.functional.softplus(self.alpha_p)
237
+ alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
238
+ return torch.where(
239
+ x > 0,
240
+ alpha_p * x * x + self.beta * x,
241
+ (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
242
+ )
243
+
244
+ def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
245
+ """Firewall function to prevent torch.compile from seeing .item()"""
246
+ assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
247
+ original_shape = x.shape
248
+ # CUDA kernel expects 3D tensors, reshape if needed
249
+ while x.dim() < 3:
250
+ x = x.unsqueeze(0)
251
+ if x.dim() > 3:
252
+ x = x.view(-1, 1, x.size(-1))
253
+ if original_shape != x.shape:
254
+ logger.warning_once(
255
+ "Warning: xIELU input tensor expects 3 dimensions"
256
+ " but got (shape: %s). Reshaping to (shape: %s).\n"
257
+ "Note: For SGLang this may be expected if sending"
258
+ "[B*S,D] instead of [B,S,D].",
259
+ original_shape,
260
+ x.shape,
261
+ )
262
+ result = self._xielu_cuda_obj.forward(
263
+ x,
264
+ self.alpha_p,
265
+ self.alpha_n,
266
+ # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
267
+ self._beta_scalar,
268
+ self._eps_scalar,
269
+ self.with_vector_loads,
270
+ )
271
+ return result.view(original_shape)
272
+
273
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
274
+ if self._xielu_cuda_obj is not None and input.is_cuda:
275
+ if not torch._dynamo.is_compiling():
276
+ return self._xielu_cuda_fn(input)
277
+ else:
278
+ logger.warning_once(
279
+ "torch._dynamo is compiling, using Python version of xIELU."
280
+ )
281
+ return self._xielu_python(input)
282
+
283
+
153
284
  class ScaledActivation(nn.Module):
154
285
  """An activation function with post-scale parameters.
155
286
 
@@ -197,6 +328,7 @@ _ACTIVATION_REGISTRY = {
197
328
  "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
198
329
  "gelu_new": NewGELU(),
199
330
  "relu2": ReLU2(),
331
+ "xielu": XIELU(),
200
332
  }
201
333
 
202
334
 
@@ -242,7 +374,9 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
242
374
  return nn.Identity()
243
375
 
244
376
 
245
- if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
377
# Log once at import time when none of the accelerated sgl-kernel backends
# (CUDA, NPU, AMX-capable CPU, HIP, XPU) is present on this platform.
_sgl_kernel_platform_supported = (
    _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
)
if not _sgl_kernel_platform_supported:
    logger.info(
        "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
    )