sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/awq.py
@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda, is_hip
 
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         orig_dtype = x.dtype
         x = x.half()
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
             w2_zeros=layer.w2_qzeros,
             num_bits=self.quant_config.weight_bits,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
sglang/srt/layers/quantization/base_config.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import inspect
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import torch
@@ -10,7 +11,7 @@ from torch import nn
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
 
 
 class QuantizeMethodBase(ABC):
@@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
         raise NotImplementedError
 
+    @abstractmethod
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        raise NotImplementedError
+
     @abstractmethod
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: DispatchOutput,
+    ) -> CombineInput:
        raise NotImplementedError
 
 
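Every quantization method touched in this release follows the migration visible in these hunks: the MoE layer first calls create_moe_runner() with the MoeRunnerConfig, and apply() now takes the token dispatcher's DispatchOutput and returns a CombineInput instead of a raw tensor. The following is a minimal sketch of a method written against that contract; the class name and its pass-through expert compute are illustrative assumptions, not code from this package.

# Illustrative sketch only: the signatures mirror the new FusedMoEMethodBase
# contract above; the pass-through compute is a placeholder, not a real scheme.
import torch

from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher import (
    StandardCombineInput,
    StandardDispatchOutput,
)
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase


class PassthroughMoEMethod(FusedMoEMethodBase):
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # A real method would register w13/w2 expert weights on `layer` here.
        pass

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Store the config; Triton-backed methods also build a MoeRunner here.
        self.moe_runner_config = moe_runner_config

    def apply(
        self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput
    ) -> StandardCombineInput:
        # Unpack the dispatcher output, run the (here: identity) expert
        # compute, and wrap the result for the combine step.
        x = dispatch_output.hidden_states
        return StandardCombineInput(hidden_states=x)

Triton-backed methods such as BlockInt8MoEMethod below additionally construct MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) in create_moe_runner and delegate apply() to self.runner.run(dispatch_output, quant_info).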
sglang/srt/layers/quantization/blockwise_int8.py
@@ -9,6 +9,8 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
         # Required by column parallel or enabling merged weights
-        if intermediate_size % block_n != 0:
+        if intermediate_size_per_partition % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
-                f"{intermediate_size} is not divisible by "
+                f"{intermediate_size_per_partition} is not divisible by "
                 f"weight quantization block_n = {block_n}."
             )
         if tp_size > 1:
             # Required by row parallel
-            if intermediate_size % block_k != 0:
+            if intermediate_size_per_partition % block_k != 0:
                 raise ValueError(
                     f"The input_size of down's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
 
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
             ),
             requires_grad=False,
         )
@@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
             ),
             requires_grad=False,
         )
@@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.ones(
                 num_experts,
-                2 * ((intermediate_size + block_n - 1) // block_n),
+                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                 (hidden_size + block_k - 1) // block_k,
                 dtype=torch.float32,
             ),
@@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
             torch.ones(
                 num_experts,
                 (hidden_size + block_n - 1) // block_n,
-                (intermediate_size + block_k - 1) // block_k,
+                (intermediate_size_per_partition + block_k - 1) // block_k,
                 dtype=torch.float32,
             ),
             requires_grad=False,
@@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         # Block quant doesn't need to process weights after loading
        return
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        # Expert fusion with INT8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
-            w1_scale=(layer.w13_weight_scale_inv),
-            w2_scale=(layer.w2_weight_scale_inv),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale_inv,
+            w2_scale=layer.w2_weight_scale_inv,
+            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )
+
+        return self.runner.run(dispatch_output, quant_info)
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -11,6 +11,8 @@ import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
@@ -30,8 +32,10 @@ from sglang.srt.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
     )
@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         )
         torch.cuda.empty_cache()
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import fused_experts
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
 
         if (
             _use_aiter
@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             and moe_runner_config.apply_router_weight_on_input
         ):
             topk_weights, topk_ids, _ = topk_output
-            return rocm_fused_experts_tkw1(
+            output = rocm_fused_experts_tkw1(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
             )
+            return StandardCombineInput(hidden_states=output)
         else:
-            return fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_output=topk_output,
-                moe_runner_config=moe_runner_config,
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
                 use_fp8_w8a8=True,
                 per_channel_quant=self.weight_quant.strategy
                 == QuantizationStrategy.CHANNEL,
-                w1_scale=layer.w13_weight_scale,
+                w13_scale=layer.w13_weight_scale,
                 w2_scale=layer.w2_weight_scale,
-                a1_scale=layer.w13_input_scale,
+                a13_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
             )
+            return self.runner.run(dispatch_output, quant_info)
 
 
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             params_dtype == torch.float16
         ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
 
-        intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
-
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         # In the case where we have actorder/g_idx,
         # we do not partition the w2 scales
         load_full_w2 = self.actorder and self.group_size != -1
-        w2_scales_size = (
-            intermediate_size_full if load_full_w2 else intermediate_size_per_partition
-        )
 
-        self.is_k_full = (not self.actorder) or (
-            intermediate_size_per_partition == intermediate_size_full
-        )
+        if load_full_w2:
+            w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
+        else:
+            w2_scales_size = intermediate_size_per_partition
+
+        self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         )
         replace_tensor("w2_weight_scale", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         topk_weights, topk_ids, router_logits = topk_output
 
-        return torch.ops.vllm.fused_marlin_moe(
+        output = torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             num_bits=self.num_bits,
             is_k_full=self.is_k_full,
         )
+        return StandardCombineInput(hidden_states=output)
sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -2,10 +2,12 @@
 
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
 
 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
     "CompressedTensorsW8A16Fp8",
+    "CompressedTensorsW8A8Int8",
 ]
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
 
 
 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_scale = layer.weight_scale.data
 
-        layer.weight = Parameter(weight.t(), requires_grad=False)
+        if _use_aiter:
+            layer.weight = Parameter(
+                shuffle_weight(weight, (16, 16)), requires_grad=False
+            )
+        else:
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
         # required by torch.compile to be torch.nn.Parameter
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py (new file)
@@ -0,0 +1,173 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, Optional
+
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+from torch.nn import Parameter
+
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import int8_scaled_mm
+
+
+class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
+
+    def __init__(
+        self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
+    ):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # lovelace and up
+        return 89
+
+    def process_weights_after_loading(self, layer) -> None:
+        # If per tensor, when we have a fused module (e.g. QKV) with per
+        # tensor scales (thus N scales being passed to the kernel),
+        # requantize so we can always run per channel
+        if self.strategy == QuantizationStrategy.TENSOR:
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                logical_widths=layer.logical_widths,
+            )
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+
+        # If channelwise, scales are already lined up, so just transpose.
+        elif self.strategy == QuantizationStrategy.CHANNEL:
+            weight = layer.weight
+            weight_scale = layer.weight_scale.data
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            # required by torch.compile to be torch.nn.Parameter
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        else:
+            raise ValueError(f"Unknown quantization strategy {self.strategy}")
+
+        # INPUT SCALE
+        if self.is_static_input_scheme and hasattr(layer, "input_scale"):
+            if self.input_symmetric:
+                layer.input_scale = Parameter(
+                    layer.input_scale.max(), requires_grad=False
+                )
+            else:
+                input_scale = layer.input_scale
+                input_zero_point = layer.input_zero_point
+
+                # reconstruct the ranges
+                int8_traits = torch.iinfo(torch.int8)
+                azps = input_zero_point.to(dtype=torch.int32)
+                range_max = (input_scale * (int8_traits.max - azps)).max()
+                range_min = (input_scale * (int8_traits.min - azps)).min()
+
+                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
+
+                # AZP loaded as int8 but used as int32
+                azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
+
+                layer.input_scale = Parameter(scale, requires_grad=False)
+                layer.input_zero_point = Parameter(azp, requires_grad=False)
+        else:
+            layer.input_scale = None
+            layer.input_zero_point = None
+
+        # azp_adj is the AZP adjustment term, used to account for weights.
+        # It does not depend on scales or azp, so it is the same for
+        # static and dynamic quantization.
+        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
+        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+        if not self.input_symmetric:
+            weight = layer.weight
+            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+            if self.is_static_input_scheme:
+                # cutlass_w8a8 requires azp to be folded into azp_adj
+                # in the per-tensor case
+                azp_adj = layer.input_zero_point * azp_adj
+            layer.azp_adj = Parameter(azp_adj, requires_grad=False)
+        else:
+            layer.azp_adj = None
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: list[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition, input_size_per_partition, dtype=torch.int8
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            input_scale = PerTensorScaleParameter(
+                data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
+            )
+            layer.register_parameter("input_scale", input_scale)
+
+            if not self.input_symmetric:
+                # Note: compressed-tensors stores the zp using the same dtype
+                # as the weights
+                # AZP loaded as int8 but used as int32
+                input_zero_point = PerTensorScaleParameter(
+                    data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
+                )
+                layer.register_parameter("input_zero_point", input_zero_point)
+
+    def apply_weights(
+        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
+    ) -> torch.Tensor:
+        # TODO: add cutlass_scaled_mm_azp support
+        x_q, x_scale = per_token_quant_int8(x)
+
+        return int8_scaled_mm(
+            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+        )
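The new scheme's apply_weights() quantizes activations per token and runs an int8 GEMM whose int32 accumulator is rescaled by the activation and weight scales. Below is a plain-PyTorch reference sketch of that arithmetic; the exact semantics of per_token_quant_int8 and sgl_kernel's int8_scaled_mm (symmetric absmax quantization, int32 accumulation) are assumptions made for illustration, not taken from the kernel sources.

# Reference sketch under assumed kernel semantics; not the sgl_kernel code.
import torch


def per_token_quant_int8_ref(x: torch.Tensor):
    # Symmetric per-token (per-row) absmax quantization to int8.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale


def int8_scaled_mm_ref(x_q, weight_t, x_scale, w_scale, out_dtype, bias=None):
    # weight_t is the transposed int8 weight the scheme stores, shaped
    # (in_features, out_features); accumulate in int32 (CPU-friendly reference),
    # then dequantize with the per-token and per-channel (or per-tensor) scales.
    acc = x_q.to(torch.int32) @ weight_t.to(torch.int32)
    out = acc.to(torch.float32) * x_scale.to(torch.float32)
    out = out * w_scale.to(torch.float32).reshape(1, -1)
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)


if __name__ == "__main__":
    x = torch.randn(4, 64)
    w = torch.randint(-128, 128, (32, 64), dtype=torch.int8)  # (out, in)
    w_scale = torch.rand(32, 1) * 0.01
    x_q, x_scale = per_token_quant_int8_ref(x)
    y = int8_scaled_mm_ref(x_q, w.t(), x_scale, w_scale, out_dtype=x.dtype)
    print(y.shape)  # torch.Size([4, 32])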
sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -1,8 +1,6 @@
 import logging
 
-import torch
-
-from sglang.srt.utils import get_bool_env_var, get_device_sm
+from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
 
 logger = logging.getLogger(__name__)
 
@@ -15,18 +13,12 @@ def _compute_enable_deep_gemm():
     try:
         import deep_gemm
     except ImportError:
-        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
         return False
 
     return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
 
 
-def _is_blackwell_arch() -> bool:
-    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
-    return major == 10
-
-
 ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
 
-DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
+DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell()
 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
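The is_blackwell() helper that replaces the local check presumably performs the same compute-capability test as the removed _is_blackwell_arch; a hedged sketch of that logic, assuming it still keys off the CUDA compute capability major version:

# Sketch of the check the removed helper performed; the actual
# sglang.srt.utils.is_blackwell may guard or cache this differently.
import torch


def is_blackwell_sketch() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return major == 10  # SM 10.x corresponds to Blackwell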