sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
+++ b/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -18,6 +19,8 @@ if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
     from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 
+_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
+
 
 # TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
@@ -31,6 +34,9 @@ def grouped_gemm_nt_f8f8bf16_masked(
     _, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
@@ -53,6 +59,9 @@ def grouped_gemm_nt_f8f8bf16_contig(
     num_groups, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
 
@@ -67,6 +76,9 @@ def gemm_nt_f8f8bf16(
     num_groups = 1
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.fp8_gemm_nt(
             lhs,
@@ -90,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
         yield
     finally:
         deep_gemm.set_num_sms(original_num_sms)
+
+
+def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
+    if not _SANITY_CHECK:
+        return
+
+    x, x_scale = x_fp8
+
+    if x_scale.dtype == torch.int:
+        return
+
+    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0
+
+    x_scale_ceil = ceil_to_ue8m0(x_scale)
+    assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
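Note: `_sanity_check_input` is opt-in via the SGLANG_DEEPGEMM_SANITY_CHECK environment variable, skips scales already packed into int32, and otherwise asserts that every FP8 scale is UE8M0-aligned. A minimal sketch of what that alignment means, assuming `ceil_to_ue8m0` (imported above from sglang.srt.layers.quantization.fp8_utils) rounds each scale up to the nearest power of two, which is the standard reading of the exponent-only UE8M0 format:

    import torch

    def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
        # Assumed definition, for illustration: UE8M0 scales are
        # exponent-only, so representable values are exact powers of two.
        return torch.exp2(torch.ceil(torch.log2(x)))

    scales = torch.tensor([0.25, 1.0, 8.0])  # powers of two: check passes
    assert torch.all(scales == ceil_to_ue8m0(scales))

    bad = torch.tensor([0.3])  # rounds up to 0.5: the assert would fire
    assert not torch.all(bad == ceil_to_ue8m0(bad))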
--- a/sglang/srt/layers/quantization/fp8.py
+++ b/sglang/srt/layers/quantization/fp8.py
@@ -30,6 +30,9 @@ except ImportError:
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -81,7 +84,11 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        DispatchOutput,
+        StandardDispatchOutput,
+    )
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 
@@ -345,11 +352,14 @@ class Fp8LinearMethod(LinearMethodBase):
                     _is_cpu_amx_available
                 ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
                 _amx_process_weight_after_loading(layer, ["weight"])
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    layer.weight_scale_inv.data, requires_grad=False
+                )
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-                layer.weight = Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
+                layer.weight.data = weight.data
+                layer.weight_scale_inv.data = weight_scale.data
         else:
             layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
@@ -527,7 +537,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -543,18 +553,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
        # Required by column parallel or enabling merged weights
-        if intermediate_size % block_n != 0:
+        if intermediate_size_per_partition % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
-                f"{intermediate_size} is not divisible by "
+                f"{intermediate_size_per_partition} is not divisible by "
                 f"weight quantization block_n = {block_n}."
             )
         if tp_size > 1:
             # Required by row parallel
-            if intermediate_size % block_k != 0:
+            if intermediate_size_per_partition % block_k != 0:
                 raise ValueError(
                     f"The input_size of down's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
 
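A quick worked check of the constraints above with the common [128, 128] quantization block shape (the sizes here are hypothetical, for illustration only):

    block_n = block_k = 128
    intermediate_size_per_partition = 6144 // 4  # hypothetical: 6144 sharded over tp_size=4

    # Column parallel: gate/up output dim must divide evenly by block_n.
    assert intermediate_size_per_partition % block_n == 0
    # Row parallel (tp_size > 1): down-proj input dim must divide by block_k.
    assert intermediate_size_per_partition % block_k == 0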
@@ -564,7 +574,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 hidden_size // 8,
                 dtype=params_dtype,
             ),
@@ -572,20 +582,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             w2_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition // 8,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
         else:
             w13_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    hidden_size,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
             w2_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
@@ -601,7 +620,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
                     num_experts,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                     (hidden_size + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
@@ -611,7 +630,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 torch.ones(
                     num_experts,
                     (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
+                    (intermediate_size_per_partition + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
                 requires_grad=False,
@@ -619,11 +638,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
         layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
         assert self.quant_config.activation_scheme == "dynamic"
-        if (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and (is_sm100_supported() or is_sm90_supported())
-        ):
+        if self.use_cutlass_fused_experts_fp8:
             self.ab_strides1 = torch.full(
                 (num_experts,),
                 hidden_size,
@@ -632,13 +647,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             self.c_strides1 = torch.full(
                 (num_experts,),
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 device=w13_weight.device,
                 dtype=torch.int64,
             )
             self.ab_strides2 = torch.full(
                 (num_experts,),
-                intermediate_size,
+                intermediate_size_per_partition,
                 device=w2_weight.device,
                 dtype=torch.int64,
             )
@@ -691,7 +706,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
-                torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=torch.float32,
+                ),
                 requires_grad=False,
             )
             w2_weight_scale1 = torch.nn.Parameter(
@@ -984,14 +1003,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
         torch.cuda.empty_cache()
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        dispatch_output: DispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
@@ -1001,7 +1029,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
 
-            return torch.ops.sgl_kernel.fused_experts_cpu(
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -1017,6 +1045,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 None,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
 
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
@@ -1027,7 +1056,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.no_combine,
             )
             if ret is not None:
-                return ret
+                return StandardCombineInput(hidden_states=ret)
 
         if self.use_cutlass_fused_experts_fp8:
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
@@ -1056,17 +1085,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
            )
-            # Scale by routed_scaling_factor is fused into select_experts.
-            return output
-        # Expert fusion with FP8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+            return StandardCombineInput(hidden_states=output)
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            w1_scale=(
+            w13_scale=(
                 layer.w13_weight_scale_inv
                 if self.block_quant
                 else layer.w13_weight_scale
@@ -1074,20 +1099,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w2_scale=(
                 layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
             ),
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        activation = moe_runner_config.activation
-        routed_scaling_factor = moe_runner_config.routed_scaling_factor
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        activation = self.moe_runner_config.activation
+        routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
 
         from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
 
@@ -1108,10 +1135,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             and topk_config.topk_group is not None
         ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
 
-        if topk_config.correction_bias is None:
-            correction_bias = topk_config.correction_bias.to(x.dtype)
-        else:
-            correction_bias = None
+        correction_bias = (
+            None
+            if topk_config.correction_bias is None
+            else topk_config.correction_bias.to(x.dtype)
+        )
+
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=correction_bias,
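The pattern repeated throughout this refactor: `apply` no longer receives `(x, topk_output, moe_runner_config)`; the runner config is captured once by `create_moe_runner`, and `apply` maps a dispatch output to a combine input. A minimal self-contained sketch of the calling convention, using simplified stand-ins rather than sglang's real dataclasses:

    from dataclasses import dataclass

    import torch

    @dataclass
    class DispatchOutputStub:  # stand-in for StandardDispatchOutput
        hidden_states: torch.Tensor
        topk_output: tuple  # (topk_weights, topk_ids, router_logits)

    @dataclass
    class CombineInputStub:  # stand-in for StandardCombineInput
        hidden_states: torch.Tensor

    class QuantMethodStub:
        def create_moe_runner(self, layer, moe_runner_config):
            # Stored once; apply() reads it instead of taking a per-call argument.
            self.moe_runner_config = moe_runner_config

        def apply(self, layer, dispatch_output: DispatchOutputStub) -> CombineInputStub:
            x = dispatch_output.hidden_states
            # ... expert computation elided ...
            return CombineInputStub(hidden_states=x)

    method = QuantMethodStub()
    method.create_moe_runner(layer=None, moe_runner_config={"activation": "silu"})
    out = method.apply(None, DispatchOutputStub(torch.randn(4, 8), topk_output=()))
    assert isinstance(out, CombineInputStub)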
--- a/sglang/srt/layers/quantization/fp8_utils.py
+++ b/sglang/srt/layers/quantization/fp8_utils.py
@@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple
 
 import torch
 
+from sglang.srt import offloader
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
@@ -45,7 +46,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
@@ -248,11 +249,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
 
-    # NOTE(alcanderian): Useless when scale is packed to int32
-    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
-    #     _check_ue8m0("x_scale", x_scale)
-    #     _check_ue8m0("weight_scale", ws)
-
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -261,11 +257,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
-def _check_ue8m0(name, x):
-    x_ceil = ceil_to_ue8m0(x)
-    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
-
-
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -427,10 +418,14 @@ def block_quant_dequant(
 def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
     assert isinstance(weight, torch.nn.Parameter)
     assert isinstance(weight_scale_inv, torch.nn.Parameter)
-    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
-        weight, weight_scale_inv, weight_block_size
+
+    new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
+        weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
     )
 
+    offloader.update_param(weight, new_weight)
+    weight_scale_inv.data = new_weight_scale_inv
+
 
 def _requant_weight_ue8m0(
     weight: torch.Tensor,
@@ -652,25 +647,49 @@ def apply_fp8_linear(
         use_per_token_if_dynamic
         and not per_tensor_weights
         and not per_tensor_activations
-        and USE_ROWWISE_TORCH_SCALED_MM
+        and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
     ):
-        # For now validated on ROCm platform
-        # fp8 rowwise scaling in torch._scaled_mm is introduced in
-        # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-        # and ROCm 6.3, which only exists in torch 2.7 and above.
-        # For CUDA platform please validate if the
-        # torch._scaled_mm support rowwise scaled GEMM
-        # Fused GEMM_DQ Rowwise GEMM
-        output = torch._scaled_mm(
-            qinput,
-            weight,
-            out_dtype=input.dtype,
-            scale_a=x_scale,
-            scale_b=weight_scale.t(),
-            bias=bias,
-        )
-        return _process_scaled_mm_output(output, input_2d.shape, output_shape)
-
+        # Reaching this branch means dynamic per-token / per-channel quantization:
+        # a per-token scale for the input matrix (one scale factor per row/token)
+        # and a per-channel scale for the weight matrix (one per column/channel).
+        if _use_aiter:
+            # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+            #   XQ      -> input tensor, shape = (m, k)
+            #   WQ      -> weight tensor, shape = (n, k); preshuffled for better perf
+            #   x_scale -> input scale tensor, shape = (m, 1)
+            #   w_scale -> weight scale tensor, shape = (n, 1)
+            #   dtype   -> output dtype
+            output = gemm_a8w8_bpreshuffle(
+                XQ=qinput,
+                WQ=weight,
+                x_scale=x_scale,
+                w_scale=weight_scale,
+                dtype=input.dtype,
+            )
+            if bias is not None:
+                output += bias
+            return _process_scaled_mm_output(
+                output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+            )
+        else:
+            # For now validated on the ROCm platform:
+            # fp8 rowwise scaling in torch._scaled_mm was introduced in
+            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+            # and ROCm 6.3, which only exists in torch 2.7 and above.
+            # On CUDA, please validate that torch._scaled_mm supports
+            # rowwise scaled GEMM.
+            # Fused GEMM_DQ Rowwise GEMM
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale.t(),
+                bias=bias,
+            )
+            return _process_scaled_mm_output(
+                output, input_2d.shape, output_shape
+            )
     else:
         # Fallback for channelwise case, where we use unfused DQ
         # due to limitations with scaled_mm
@@ -713,7 +732,7 @@ def apply_fp8_linear(
     # final solution should be: 1. add support to per-tensor activation scaling.
    # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
     if _is_hip and weight_scale.numel() == 1:
-        qinput, x_scale = ops.scaled_fp8_quant(
+        qinput, x_scale = scaled_fp8_quant(
             input_2d,
             input_scale,
             use_per_token_if_dynamic=use_per_token_if_dynamic,
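For reference, a minimal sketch of the retained `torch._scaled_mm` rowwise branch, assuming a CUDA device and a torch build where rowwise FP8 scaling is available (per the in-code comment: hipBLASLt via torch >= 2.7 on ROCm 6.3; validate separately on CUDA). The shapes follow the diff: one scale per activation row (token) and one per weight output channel:

    import torch

    m, k, n = 16, 64, 32
    qinput = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)  # (m, k) activations
    weight = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()  # (k, n), column-major
    x_scale = torch.rand(m, 1, device="cuda")  # float32, one scale per token
    weight_scale = torch.rand(n, 1, device="cuda")  # float32, one scale per channel

    output = torch._scaled_mm(
        qinput,
        weight,
        out_dtype=torch.bfloat16,
        scale_a=x_scale,
        scale_b=weight_scale.t(),  # (1, n), matching the diff's weight_scale.t()
    )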
--- a/sglang/srt/layers/quantization/gptq.py
+++ b/sglang/srt/layers/quantization/gptq.py
@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda
 
@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        intermediate_size = extra_weight_attrs.pop("intermediate_size")
-
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1
 
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            w2_scales_size = (
-                intermediate_size
-                if self.quant_config.desc_act
-                else intermediate_size_per_partition
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
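The `is_k_full` change above drops the `intermediate_size` plumbing through `extra_weight_attrs`: with act-ordered (desc_act) GPTQ weights, Marlin sees the full K dimension only when the expert weights are not sharded across tensor-parallel ranks, which `layer.moe_tp_size == 1` now expresses directly (equivalent, by assumption, to the old `intermediate_size_per_partition == intermediate_size` test). Restated as a self-contained rule:

    def is_k_full(desc_act: bool, moe_tp_size: int) -> bool:
        # Without act-order, K is always "full"; with act-order, K is
        # only full when the MoE weights are not split across TP ranks.
        return (not desc_act) or moe_tp_size == 1

    assert is_k_full(desc_act=False, moe_tp_size=4)  # no act-order
    assert is_k_full(desc_act=True, moe_tp_size=1)  # act-order, unsharded
    assert not is_k_full(desc_act=True, moe_tp_size=4)  # act-order + TP sharding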