sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
17
17
  get_tp_group,
18
18
  tensor_model_parallel_all_reduce,
19
19
  )
20
+ from sglang.srt.utils import get_bool_env_var, is_hip
20
21
 
21
22
  if TYPE_CHECKING:
22
23
  from sglang.srt.configs.model_config import ModelConfig
@@ -36,6 +37,9 @@ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
36
37
  _LOCAL_ATTN_DP_RANK: Optional[int] = None
37
38
  _ENABLE_DP_ATTENTION_FLAG: bool = False
38
39
 
40
+ _is_hip = is_hip()
41
+ _USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")
42
+
39
43
 
40
44
  class DpPaddingMode(IntEnum):
41
45
 
@@ -51,7 +55,12 @@ class DpPaddingMode(IntEnum):
51
55
  return self == DpPaddingMode.SUM_LEN
52
56
 
53
57
  @classmethod
54
- def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
58
+ def get_dp_padding_mode(
59
+ cls, is_extend_in_batch, global_num_tokens: List[int]
60
+ ) -> DpPaddingMode:
61
+ if is_extend_in_batch:
62
+ return DpPaddingMode.SUM_LEN
63
+
55
64
  # we choose the mode that minimizes the communication cost
56
65
  max_len = max(global_num_tokens)
57
66
  sum_len = sum(global_num_tokens)
@@ -62,7 +71,12 @@ class DpPaddingMode(IntEnum):
62
71
 
63
72
  @classmethod
64
73
  def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
65
- return cls.MAX_LEN
74
+ # TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
75
+ # it can be safely removed later, once RCCL fixed
76
+ if _USE_ROCM700A_WA:
77
+ return cls.SUM_LEN
78
+ else:
79
+ return cls.MAX_LEN
66
80
 
67
81
 
68
82
  class _DpGatheredBufferWrapper:
@@ -119,6 +133,18 @@ class _DpGatheredBufferWrapper:
119
133
  def get_dp_global_num_tokens(cls) -> List[int]:
120
134
  return cls._global_num_tokens
121
135
 
136
+ @classmethod
137
+ def get_dp_hidden_size(cls) -> int:
138
+ return cls._hidden_size
139
+
140
+ @classmethod
141
+ def get_dp_dtype(cls) -> torch.dtype:
142
+ return cls._dtype
143
+
144
+ @classmethod
145
+ def get_dp_device(cls) -> torch.device:
146
+ return cls._device
147
+
122
148
 
123
149
  def set_dp_buffer_len(
124
150
  global_dp_buffer_len: int,
@@ -150,6 +176,18 @@ def get_dp_global_num_tokens() -> List[int]:
150
176
  return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
151
177
 
152
178
 
179
+ def get_dp_hidden_size() -> int:
180
+ return _DpGatheredBufferWrapper.get_dp_hidden_size()
181
+
182
+
183
+ def get_dp_dtype() -> torch.dtype:
184
+ return _DpGatheredBufferWrapper.get_dp_dtype()
185
+
186
+
187
+ def get_dp_device() -> torch.device:
188
+ return _DpGatheredBufferWrapper.get_dp_device()
189
+
190
+
153
191
  def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
154
192
  if not enable_dp_attention:
155
193
  return tp_rank, tp_size, 0
@@ -225,6 +263,7 @@ def initialize_dp_attention(
225
263
  use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
226
264
  use_pymscclpp=False,
227
265
  use_custom_allreduce=False,
266
+ use_torch_symm_mem=False,
228
267
  use_hpu_communicator=False,
229
268
  use_xpu_communicator=False,
230
269
  use_npu_communicator=False,
@@ -187,7 +187,9 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
187
187
 
188
188
  def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
189
189
  assert len(x.shape) == 2
190
- assert x.shape == residual.shape and x.dtype == residual.dtype
190
+ assert (
191
+ x.shape == residual.shape and x.dtype == residual.dtype
192
+ ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
191
193
  output, mid = torch.empty_like(x), torch.empty_like(x)
192
194
  bs, hidden_dim = x.shape
193
195
  if autotune:
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
18
18
 
19
19
  import torch
20
20
  import torch.nn as nn
21
+ from packaging.version import Version
21
22
 
22
23
  from sglang.srt.custom_op import CustomOp
23
24
  from sglang.srt.utils import (
@@ -25,32 +26,38 @@ from sglang.srt.utils import (
25
26
  get_bool_env_var,
26
27
  is_cpu,
27
28
  is_cuda,
29
+ is_flashinfer_available,
28
30
  is_hip,
29
31
  is_npu,
32
+ is_xpu,
30
33
  supports_custom_op,
31
34
  )
32
35
 
33
36
  _is_cuda = is_cuda()
37
+ _is_flashinfer_available = is_flashinfer_available()
34
38
  _is_hip = is_hip()
35
39
  _is_npu = is_npu()
36
40
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
37
41
  _is_cpu_amx_available = cpu_has_amx_support()
38
42
  _is_cpu = is_cpu()
43
+ _is_xpu = is_xpu()
39
44
 
40
45
  if _is_cuda:
41
- from sgl_kernel import (
42
- fused_add_rmsnorm,
43
- gemma_fused_add_rmsnorm,
44
- gemma_rmsnorm,
45
- rmsnorm,
46
- )
46
+ if _is_flashinfer_available:
47
+ from flashinfer.norm import fused_add_rmsnorm
48
+ else:
49
+ from sgl_kernel import fused_add_rmsnorm
50
+ from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
47
51
 
48
52
  if _use_aiter:
49
53
  from aiter import rmsnorm2d_fwd as rms_norm
50
54
  from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
51
55
  elif _is_hip:
56
+ import vllm
52
57
  from vllm._custom_ops import fused_add_rms_norm, rms_norm
53
58
 
59
+ _vllm_version = Version(vllm.__version__)
60
+
54
61
  logger = logging.getLogger(__name__)
55
62
 
56
63
  if _is_npu:
@@ -73,6 +80,8 @@ class RMSNorm(CustomOp):
73
80
  )
74
81
  if _use_aiter:
75
82
  self._forward_method = self.forward_aiter
83
+ if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
84
+ self._forward_method = self.forward_native
76
85
 
77
86
  def forward_cuda(
78
87
  self,
@@ -127,8 +136,21 @@ class RMSNorm(CustomOp):
127
136
  # NOTE: Remove this if aiter kernel supports discontinuous input
128
137
  x = x.contiguous()
129
138
  if residual is not None:
130
- fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
131
- return x, residual
139
+ if _vllm_version < Version("0.9"):
140
+ fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
141
+ return x, residual
142
+ else:
143
+ residual_out = torch.empty_like(x)
144
+ output = torch.empty_like(x)
145
+ fused_add_rms_norm(
146
+ output,
147
+ x,
148
+ residual_out,
149
+ residual,
150
+ self.weight.data,
151
+ self.variance_epsilon,
152
+ )
153
+ return output, residual_out
132
154
  out = torch.empty_like(x)
133
155
  rms_norm(out, x, self.weight.data, self.variance_epsilon)
134
156
  return out
@@ -271,16 +293,11 @@ class GemmaRMSNorm(CustomOp):
271
293
  x: torch.Tensor,
272
294
  residual: Optional[torch.Tensor] = None,
273
295
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
274
- orig_dtype = x.dtype
275
296
  if residual is not None:
276
297
  x = x + residual
277
298
  residual = x
278
299
 
279
- x = x.float()
280
- variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
281
- x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
282
- x = x * (1.0 + self.weight.float())
283
- x = x.to(orig_dtype)
300
+ x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
284
301
  return x if residual is None else (x, residual)
285
302
 
286
303
 
@@ -312,7 +329,9 @@ class Gemma3RMSNorm(CustomOp):
312
329
  return f"{tuple(self.weight.shape)}, eps={self.eps}"
313
330
 
314
331
 
315
- if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
332
+ if not (
333
+ _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
334
+ ):
316
335
  logger.info(
317
336
  "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
318
337
  )
@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
31
31
  _ColumnvLLMParameter,
32
32
  )
33
33
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
34
+ from sglang.srt.layers.utils import pad_or_narrow_weight
34
35
  from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
35
36
 
36
37
  if TYPE_CHECKING:
@@ -235,9 +236,8 @@ class ReplicatedLinear(LinearBase):
235
236
  loaded_weight = loaded_weight[:1]
236
237
  else:
237
238
  raise ValueError(f"{loaded_weight} are not all equal")
238
- assert (
239
- param.size() == loaded_weight.size()
240
- ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
239
+
240
+ assert param.size() == loaded_weight.size()
241
241
  param.data.copy_(loaded_weight)
242
242
 
243
243
  def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -626,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
626
626
  # bitsandbytes loads the weights of the specific portion
627
627
  # no need to narrow here
628
628
  if not use_bitsandbytes_4bit and not self.use_presharded_weights:
629
- loaded_weight = loaded_weight.narrow(
630
- output_dim, start_idx, shard_size
631
- )
629
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
630
+ end_idx = start_idx + shard_size
631
+ if end_idx > loaded_weight.shape[output_dim]:
632
+ loaded_weight = pad_or_narrow_weight(
633
+ loaded_weight, output_dim, start_idx, shard_size
634
+ )
635
+ else:
636
+ loaded_weight = loaded_weight.narrow(
637
+ output_dim, start_idx, shard_size
638
+ )
632
639
 
633
640
  # Special case for AQLM codebooks.
634
641
  elif is_metadata:
@@ -894,6 +901,35 @@ class QKVParallelLinear(ColumnParallelLinear):
894
901
  )
895
902
  self.weight_loader_v2(param, loaded_weight_shard, shard_id)
896
903
 
904
+ def _load_qkv_block_scale(
905
+ self, param: BasevLLMParameter, loaded_weight: torch.Tensor
906
+ ):
907
+ block_n, _ = self.quant_method.quant_config.weight_block_size
908
+ q_size = self.total_num_heads * self.head_size // block_n
909
+ k_size = self.total_num_kv_heads * self.head_size // block_n
910
+ v_size = self.total_num_kv_heads * self.head_size // block_n
911
+ shard_offsets = [
912
+ # (shard_id, shard_offset, shard_size)
913
+ ("q", 0, q_size),
914
+ ("k", q_size, k_size),
915
+ ("v", q_size + k_size, v_size),
916
+ ]
917
+ for shard_id, shard_offset, shard_size in shard_offsets:
918
+ loaded_weight_shard = loaded_weight.narrow(
919
+ param.output_dim, shard_offset, shard_size
920
+ )
921
+ rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n
922
+ rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n
923
+ param.load_qkv_weight(
924
+ loaded_weight=loaded_weight_shard,
925
+ num_heads=self.num_kv_head_replicas,
926
+ shard_id=shard_id,
927
+ shard_offset=rank_shard_offset,
928
+ shard_size=rank_shard_size,
929
+ tp_rank=self.tp_rank,
930
+ use_presharded_weights=self.use_presharded_weights,
931
+ )
932
+
897
933
  def weight_loader_v2(
898
934
  self,
899
935
  param: BasevLLMParameter,
@@ -907,6 +943,9 @@ class QKVParallelLinear(ColumnParallelLinear):
907
943
  elif type(param) in (RowvLLMParameter, BasevLLMParameter):
908
944
  param.load_qkv_weight(loaded_weight=loaded_weight)
909
945
  return
946
+ elif isinstance(param, BlockQuantScaleParameter):
947
+ self._load_qkv_block_scale(param, loaded_weight)
948
+ return
910
949
  # TODO: @dsikka - move to parameter.py
911
950
  self._load_fused_module_from_checkpoint(param, loaded_weight)
912
951
  return
@@ -1271,7 +1310,16 @@ class RowParallelLinear(LinearBase):
1271
1310
  shard_size,
1272
1311
  )
1273
1312
  else:
1274
- loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
1313
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
1314
+ end_idx = start_idx + shard_size
1315
+ if end_idx > loaded_weight.shape[input_dim]:
1316
+ loaded_weight = pad_or_narrow_weight(
1317
+ loaded_weight, input_dim, start_idx, shard_size
1318
+ )
1319
+ else:
1320
+ loaded_weight = loaded_weight.narrow(
1321
+ input_dim, start_idx, shard_size
1322
+ )
1275
1323
 
1276
1324
  # Special case for loading scales off disk, which often do not
1277
1325
  # have a shape (such as in the case of AutoFP8).
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
35
35
  get_attention_dp_rank,
36
36
  get_attention_dp_size,
37
37
  get_attention_tp_size,
38
+ get_dp_device,
39
+ get_dp_dtype,
40
+ get_dp_hidden_size,
38
41
  get_global_dp_buffer,
39
42
  get_local_attention_dp_size,
40
43
  set_dp_buffer_len,
@@ -46,10 +49,12 @@ from sglang.srt.model_executor.forward_batch_info import (
46
49
  ForwardBatch,
47
50
  ForwardMode,
48
51
  )
49
- from sglang.srt.utils import dump_to_file, use_intel_amx_backend
52
+ from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
50
53
 
51
54
  logger = logging.getLogger(__name__)
52
55
 
56
+ _is_npu = is_npu()
57
+
53
58
 
54
59
  @dataclasses.dataclass
55
60
  class LogitsProcessorOutput:
@@ -67,7 +72,10 @@ class LogitsProcessorOutput:
67
72
  next_token_top_logprobs_val: Optional[List] = None
68
73
  next_token_top_logprobs_idx: Optional[List] = None
69
74
  # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
70
- next_token_token_ids_logprobs_val: Optional[List] = None
75
+ # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
76
+ next_token_token_ids_logprobs_val: Optional[
77
+ List[Union[List[float], torch.Tensor]]
78
+ ] = None
71
79
  next_token_token_ids_logprobs_idx: Optional[List] = None
72
80
 
73
81
  ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -180,10 +188,13 @@ class LogitsMetadata:
180
188
  )
181
189
  else:
182
190
  dp_local_start_pos = cumtokens[dp_rank - 1]
183
- dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
184
191
 
185
192
  self.dp_local_start_pos = dp_local_start_pos
186
- self.dp_local_num_tokens = dp_local_num_tokens
193
+ self.dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
194
+
195
+ hidden_size = get_dp_hidden_size()
196
+ dtype = get_dp_dtype()
197
+ device = get_dp_device()
187
198
 
188
199
  if self.global_num_tokens_for_logprob_cpu is not None:
189
200
  # create a smaller buffer to reduce peak memory usage
@@ -191,10 +202,13 @@ class LogitsMetadata:
191
202
  else:
192
203
  self.global_dp_buffer_len = self.global_dp_buffer_len
193
204
 
194
- set_dp_buffer_len(
195
- self.global_dp_buffer_len,
196
- self.dp_local_num_tokens,
197
- self.global_num_tokens_for_logprob_cpu,
205
+ self.gathered_buffer = torch.empty(
206
+ (
207
+ self.global_dp_buffer_len,
208
+ hidden_size,
209
+ ),
210
+ dtype=dtype,
211
+ device=device,
198
212
  )
199
213
 
200
214
 
@@ -206,6 +220,7 @@ class LogitsProcessor(nn.Module):
206
220
  self.config = config
207
221
  self.logit_scale = logit_scale
208
222
  self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
223
+ self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
209
224
  if self.use_attn_tp_group:
210
225
  self.attn_tp_size = get_attention_tp_size()
211
226
  self.do_tensor_parallel_all_gather = (
@@ -441,13 +456,17 @@ class LogitsProcessor(nn.Module):
441
456
  if self.do_tensor_parallel_all_gather_dp_attn:
442
457
  logits_metadata.compute_dp_attention_metadata()
443
458
  hidden_states, local_hidden_states = (
444
- get_global_dp_buffer(),
459
+ logits_metadata.gathered_buffer,
445
460
  hidden_states,
446
461
  )
447
462
  dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
448
463
 
449
464
  if hasattr(lm_head, "weight"):
450
- if use_intel_amx_backend(lm_head):
465
+ if self.use_fp32_lm_head:
466
+ logits = torch.matmul(
467
+ hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
468
+ )
469
+ elif use_intel_amx_backend(lm_head):
451
470
  logits = torch.ops.sgl_kernel.weight_packed_linear(
452
471
  hidden_states.to(lm_head.weight.dtype),
453
472
  lm_head.weight,
@@ -461,7 +480,15 @@ class LogitsProcessor(nn.Module):
461
480
  else:
462
481
  # GGUF models
463
482
  # TODO: use weight_packed_linear for GGUF models
464
- logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
483
+ if self.use_fp32_lm_head:
484
+ with torch.cuda.amp.autocast(enabled=False):
485
+ logits = lm_head.quant_method.apply(
486
+ lm_head, hidden_states.to(torch.float32), embedding_bias
487
+ )
488
+ else:
489
+ logits = lm_head.quant_method.apply(
490
+ lm_head, hidden_states, embedding_bias
491
+ )
465
492
 
466
493
  if self.logit_scale is not None:
467
494
  logits.mul_(self.logit_scale)
@@ -517,7 +544,12 @@ class LogitsProcessor(nn.Module):
517
544
  logits = logits[:, : self.config.vocab_size].float()
518
545
 
519
546
  if self.final_logit_softcapping:
520
- fused_softcap(logits, self.final_logit_softcapping)
547
+ if not _is_npu:
548
+ fused_softcap(logits, self.final_logit_softcapping)
549
+ else:
550
+ logits = self.final_logit_softcapping * torch.tanh(
551
+ logits / self.final_logit_softcapping
552
+ )
521
553
 
522
554
  return logits
523
555
 
@@ -1,4 +1,4 @@
1
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
1
+ from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
2
2
  from sglang.srt.layers.moe.utils import (
3
3
  DeepEPMode,
4
4
  MoeA2ABackend,
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
17
17
  __all__ = [
18
18
  "DeepEPMode",
19
19
  "MoeA2ABackend",
20
+ "MoeRunner",
20
21
  "MoeRunnerConfig",
21
22
  "MoeRunnerBackend",
22
23
  "initialize_moe_config",
@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
147
147
  k,
148
148
  )
149
149
 
150
- c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
151
- c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
150
+ c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
151
+ c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
152
152
 
153
153
  cutlass_w4a8_moe_mm(
154
154
  c1,
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
166
166
  topk,
167
167
  )
168
168
 
169
- intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
169
+ intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
170
170
  silu_and_mul(c1, intermediate)
171
171
 
172
172
  intermediate_q = torch.empty(
@@ -1104,10 +1104,10 @@ def ep_gather(
1104
1104
  input_index: torch.Tensor,
1105
1105
  output_tensor: torch.Tensor,
1106
1106
  ):
1107
- BLOCK_D = 1024 if not is_in_ci() else 128 # block size of quantization
1108
1107
  num_warps = 2
1109
1108
  num_tokens = output_tensor.shape[0]
1110
1109
  hidden_size = input_tensor.shape[1]
1110
+ BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024 # block size of quantization
1111
1111
  assert hidden_size % BLOCK_D == 0
1112
1112
  grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
1113
1113
  _fwd_kernel_ep_gather[grid](
@@ -1416,7 +1416,7 @@ def zero_experts_compute_triton(
1416
1416
  zero_expert_scales[zero_expert_mask] = 0.0
1417
1417
 
1418
1418
  normal_expert_mask = expert_indices >= num_experts
1419
- expert_indices[normal_expert_mask] = 0
1419
+ expert_indices[normal_expert_mask] = -1
1420
1420
  expert_scales[normal_expert_mask] = 0.0
1421
1421
 
1422
1422
  output = torch.zeros_like(hidden_states).to(hidden_states.device)