sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff compares two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,339 @@
+ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ from typing import Callable, Optional
+
+ import torch
+ from compressed_tensors.quantization import ActivationOrdering
+
+ # yapf conflicts with isort for this block
+ # yapf: disable
+ from sglang.srt.layers.parameter import (
+     BasevLLMParameter,
+     ChannelQuantScaleParameter,
+     GroupQuantScaleParameter,
+     PackedColumnParameter,
+     PackedvLLMParameter,
+     RowvLLMParameter,
+     permute_param_layout_,
+ )
+ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+     CompressedTensorsScheme,
+ )
+ from sglang.srt.layers.quantization.marlin_utils import (
+     MarlinLinearLayerConfig,
+     apply_gptq_marlin_linear,
+     check_marlin_supports_shape,
+     marlin_is_k_full,
+     marlin_make_empty_g_idx,
+     marlin_make_workspace,
+     marlin_permute_scales,
+     marlin_repeat_scales_on_all_ranks,
+     marlin_sort_g_idx,
+     marlin_zero_points,
+ )
+ from sglang.srt.layers.quantization.utils import (
+     get_scalar_types,
+     replace_parameter,
+     unpack_cols,
+ )
+ from sglang.srt.utils import is_cuda
+
+ _is_cuda = is_cuda()
+
+ if _is_cuda:
+     from sgl_kernel import gptq_marlin_repack
+
+
+ ScalarType, scalar_types = get_scalar_types()
+
+ logger = logging.getLogger(__name__)
+
+ __all__ = ["CompressedTensorsWNA16"]
+ WNA16_SUPPORTED_TYPES_MAP = {
+     4: scalar_types.uint4b8,
+     8: scalar_types.uint8b128
+ }
+ WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
+ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
+
+
+ class CompressedTensorsWNA16(CompressedTensorsScheme):
+     _kernel_backends_being_used: set[str] = set()
+
+     def __init__(self,
+                  strategy: str,
+                  num_bits: int,
+                  group_size: Optional[int] = None,
+                  symmetric: Optional[bool] = True,
+                  actorder: Optional[ActivationOrdering] = None):
+
+         self.pack_factor = 32 // num_bits
+         self.strategy = strategy
+         self.symmetric = symmetric
+         self.group_size = -1 if group_size is None else group_size
+         self.has_g_idx = actorder == ActivationOrdering.GROUP
+
+         if self.group_size == -1 and self.strategy != "channel":
+             raise ValueError("Marlin kernels require group quantization or "
+                              "channelwise quantization, but found no group "
+                              "size and strategy is not channelwise.")
+
+         if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
+             raise ValueError(
+                 f"Unsupported num_bits = {num_bits}. "
+                 f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
+
+         self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
+                            if not self.symmetric else
+                            WNA16_SUPPORTED_TYPES_MAP[num_bits])
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         # ampere and up
+         return 80
+
+     def create_weights(self, layer: torch.nn.Module, output_size: int,
+                        input_size: int, output_partition_sizes: list[int],
+                        input_size_per_partition: int,
+                        params_dtype: torch.dtype, weight_loader: Callable,
+                        **kwargs):
+
+         output_size_per_partition = sum(output_partition_sizes)
+
+         self.kernel_config = MarlinLinearLayerConfig(
+             full_weight_shape=(input_size, output_size),
+             partition_weight_shape=(
+                 input_size_per_partition,
+                 output_size_per_partition,
+             ),
+             weight_type=self.quant_type,
+             act_type=params_dtype,
+             group_size=self.group_size,
+             zero_points=not self.symmetric,
+             has_g_idx=self.has_g_idx
+         )
+
+         # If group_size is -1, we are in channelwise case.
+         group_size = self.group_size if self.group_size != -1 else input_size
+         row_parallel = (input_size != input_size_per_partition)
+         partition_scales = not marlin_repeat_scales_on_all_ranks(
+             self.has_g_idx, self.group_size, row_parallel)
+
+         scales_and_zp_size = input_size // group_size
+
+         if partition_scales:
+             assert input_size_per_partition % group_size == 0
+             scales_and_zp_size = input_size_per_partition // group_size
+
+         weight = PackedvLLMParameter(input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader,
+                                      packed_factor=self.pack_factor,
+                                      packed_dim=1,
+                                      data=torch.empty(
+                                          output_size_per_partition,
+                                          input_size_per_partition //
+                                          self.pack_factor,
+                                          dtype=torch.int32,
+                                      ))
+
+         weight_scale_args = {
+             "weight_loader":
+             weight_loader,
+             "data":
+             torch.empty(
+                 output_size_per_partition,
+                 scales_and_zp_size,
+                 dtype=params_dtype,
+             )
+         }
+
+         zeros_args = {
+             "weight_loader":
+             weight_loader,
+             "data":
+             torch.zeros(
+                 output_size_per_partition // self.pack_factor,
+                 scales_and_zp_size,
+                 dtype=torch.int32,
+             )
+         }
+
+         if not partition_scales:
+             weight_scale = ChannelQuantScaleParameter(output_dim=0,
+                                                       **weight_scale_args)
+
+             if not self.symmetric:
+                 qzeros = PackedColumnParameter(output_dim=0,
+                                                packed_dim=0,
+                                                packed_factor=self.pack_factor,
+                                                **zeros_args)
+         else:
+             weight_scale = GroupQuantScaleParameter(output_dim=0,
+                                                     input_dim=1,
+                                                     **weight_scale_args)
+             if not self.symmetric:
+                 qzeros = PackedvLLMParameter(input_dim=1,
+                                              output_dim=0,
+                                              packed_dim=0,
+                                              packed_factor=self.pack_factor,
+                                              **zeros_args)
+
+         # A 2D array defining the original shape of the weights
+         # before packing
+         weight_shape = BasevLLMParameter(data=torch.empty(2,
+                                                           dtype=torch.int64),
+                                          weight_loader=weight_loader)
+
+         layer.register_parameter("weight_packed", weight)
+         layer.register_parameter("weight_scale", weight_scale)
+         layer.register_parameter("weight_shape", weight_shape)
+
+         if not self.symmetric:
+             layer.register_parameter("weight_zero_point", qzeros)
+
+         # group index (for activation reordering)
+         if self.has_g_idx:
+             weight_g_idx = RowvLLMParameter(data=torch.empty(
+                 input_size_per_partition,
+                 dtype=torch.int32,
+             ),
+                                             input_dim=0,
+                                             weight_loader=weight_loader)
+             layer.register_parameter("weight_g_idx", weight_g_idx)
+
+     # Checkpoints are serialized in compressed-tensors format, which is
+     # different from the format the kernel may want. Handle repacking here.
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         # Default names since marlin requires empty parameters for these,
+         # TODO: remove this requirement from marlin (allow optional tensors)
+         self.w_q_name = "weight_packed"
+         self.w_s_name = "weight_scale"
+         self.w_zp_name = "weight_zero_point"
+         self.w_gidx_name = "weight_g_idx"
+
+         device = getattr(layer, self.w_q_name).device
+         c = self.kernel_config
+
+         check_marlin_supports_shape(
+             c.partition_weight_shape[1],  # out_features
+             c.partition_weight_shape[0],  # in_features
+             c.full_weight_shape[0],  # in_features
+             c.group_size,
+         )
+
+         row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
+         self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
+
+         # Allocate marlin workspace.
+         self.workspace = marlin_make_workspace(device)
+
+         def _transform_param(
+             layer: torch.nn.Module, name: Optional[str], fn: Callable
+         ) -> None:
+             if name is not None and getattr(layer, name, None) is not None:
+
+                 old_param = getattr(layer, name)
+                 new_param = fn(old_param)
+                 # replace the parameter with torch.nn.Parameter for TorchDynamo
+                 # compatibility
+                 replace_parameter(
+                     layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
+                 )
+
+         def transform_w_q(x):
+             assert isinstance(x, BasevLLMParameter)
+             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+             x.data = gptq_marlin_repack(
+                 x.data.contiguous(),
+                 perm=layer.g_idx_sort_indices,
+                 size_k=c.partition_weight_shape[0],
+                 size_n=c.partition_weight_shape[1],
+                 num_bits=c.weight_type.size_bits,
+             )
+             return x
+
+         def transform_w_s(x):
+             assert isinstance(x, BasevLLMParameter)
+             permute_param_layout_(x, input_dim=0, output_dim=1)
+             x.data = marlin_permute_scales(
+                 x.data.contiguous(),
+                 size_k=c.partition_weight_shape[0],
+                 size_n=c.partition_weight_shape[1],
+                 group_size=c.group_size,
+             )
+             return x
+
+         if c.has_g_idx:
+             g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                 getattr(layer, self.w_gidx_name)
+             )
+             _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+             layer.g_idx_sort_indices = g_idx_sort_indices
+         else:
+             setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+             layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+         if c.zero_points:
+             grouped_k = (
+                 c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
+             )
+             _transform_param(
+                 layer,
+                 self.w_zp_name,
+                 lambda x: marlin_zero_points(
+                     unpack_cols(
+                         x.t(),
+                         c.weight_type.size_bits,
+                         grouped_k,
+                         c.partition_weight_shape[1],
+                     ),
+                     size_k=grouped_k,
+                     size_n=c.partition_weight_shape[1],
+                     num_bits=c.weight_type.size_bits,
+                 ),
+             )
+         else:
+             setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+         _transform_param(layer, self.w_q_name, transform_w_q)
+         _transform_param(layer, self.w_s_name, transform_w_s)
+
+     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                       bias: Optional[torch.Tensor]) -> torch.Tensor:
+         c = self.kernel_config
+
+         def _get_weight_params(
+             layer: torch.nn.Module,
+         ) -> tuple[
+             torch.Tensor,  # w_q
+             torch.Tensor,  # w_s
+             Optional[torch.Tensor],  # w_zp,
+             Optional[torch.Tensor],  # w_gidx
+         ]:
+             return (
+                 getattr(layer, self.w_q_name),
+                 getattr(layer, self.w_s_name),
+                 getattr(layer, self.w_zp_name or "", None),
+                 getattr(layer, self.w_gidx_name or "", None),
+             )
+
+         w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)
+
+         # `process_weights_after_loading` will ensure w_zp and w_gidx are not
+         # None for marlin
+         return apply_gptq_marlin_linear(
+             input=x,
+             weight=w_q,
+             weight_scale=w_s,
+             weight_zp=w_zp,  # type: ignore
+             g_idx=w_gidx,  # type: ignore
+             g_idx_sort_indices=layer.g_idx_sort_indices,
+             workspace=self.workspace,
+             wtype=c.weight_type,
+             input_size_per_partition=c.partition_weight_shape[0],
+             output_size_per_partition=c.partition_weight_shape[1],
+             is_k_full=self.is_k_full,
+             bias=bias,
+         )
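As a side note on the new CompressedTensorsWNA16 scheme above: the layout set up by create_weights follows directly from pack_factor = 32 // num_bits. A minimal sketch of the resulting tensor shapes, with made-up layer sizes (the 4096x11008, group-size-128 figures are illustrative, not from the diff):

# Sketch only: shapes allocated by create_weights for a hypothetical 4-bit, group-quantized layer.
num_bits = 4
pack_factor = 32 // num_bits              # 8 quantized values per int32 word
out_features, in_features = 4096, 11008   # made-up partition sizes
group_size = 128

packed_weight = (out_features, in_features // pack_factor)              # "weight_packed", int32
scales = (out_features, in_features // group_size)                      # "weight_scale", activation dtype
zero_points = (out_features // pack_factor, in_features // group_size)  # "weight_zero_point", only if not symmetric

print(packed_weight, scales, zero_points)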
@@ -31,8 +31,8 @@ except ImportError:
  from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
  from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
  from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
- from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
  from sglang.srt.layers.parameter import (
      BlockQuantScaleParameter,
      ModelWeightParameter,
@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
      def create_moe_runner(
          self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
      ):
+
+         from sglang.srt.layers import deep_gemm_wrapper
+         from sglang.srt.layers.moe.utils import (
+             get_moe_a2a_backend,
+             get_moe_runner_backend,
+         )
+
          self.moe_runner_config = moe_runner_config
-         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+         moe_runner_backend = get_moe_runner_backend()
+
+         if moe_runner_backend.is_auto():
+             if (
+                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                 and get_moe_a2a_backend().is_deepep()
+             ):
+                 moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
+             else:
+                 moe_runner_backend = MoeRunnerBackend.TRITON
+         if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
+             self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
+         else:
+             # TODO(cwan): refactor other backends
+             pass

      def apply(
          self,
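For readability, the same selection logic as a standalone sketch (not part of the diff): the Backend enum and resolve_backend helper are invented names that stand in for MoeRunnerBackend and the checks on deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_moe_a2a_backend().

# Sketch only: "auto" resolves to DeepGEMM when JIT DeepGEMM is enabled and the all-to-all backend is DeepEP.
from enum import Enum

class Backend(Enum):
    AUTO = "auto"
    TRITON = "triton"
    DEEP_GEMM = "deep_gemm"

def resolve_backend(requested: Backend, jit_deepgemm: bool, a2a_is_deepep: bool) -> Backend:
    if requested is Backend.AUTO:
        return Backend.DEEP_GEMM if (jit_deepgemm and a2a_is_deepep) else Backend.TRITON
    return requested

print(resolve_backend(Backend.AUTO, jit_deepgemm=True, a2a_is_deepep=True))   # Backend.DEEP_GEMM
print(resolve_backend(Backend.AUTO, jit_deepgemm=True, a2a_is_deepep=False))  # Backend.TRITON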
@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
              )
              return StandardCombineInput(hidden_states=output)

-         quant_info = TritonMoeQuantInfo(
-             w13_weight=layer.w13_weight,
-             w2_weight=layer.w2_weight,
-             use_fp8_w8a8=True,
-             w13_scale=(
-                 layer.w13_weight_scale_inv
-                 if self.block_quant
-                 else layer.w13_weight_scale
-             ),
-             w2_scale=(
-                 layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-             ),
-             a13_scale=layer.w13_input_scale,
-             a2_scale=layer.w2_input_scale,
-             block_shape=self.quant_config.weight_block_size,
-         )
+         if self.runner.runner_backend.is_deep_gemm():
+
+             w13_weight = layer.w13_weight
+             w2_weight = layer.w2_weight
+
+             if self.block_quant:
+                 block_shape = self.quant_config.weight_block_size
+                 w13_scale = layer.w13_weight_scale_inv
+                 w2_scale = layer.w2_weight_scale_inv
+             else:
+                 # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+                 scale_block_size = 128
+                 block_shape = [scale_block_size, scale_block_size]
+                 w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
+                 w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
+                 w13_scale = (
+                     layer.w13_weight_scale.unsqueeze(1)
+                     .repeat_interleave(w13_scale_n, dim=1)
+                     .unsqueeze(2)
+                     .repeat_interleave(w13_scale_k, dim=2)
+                 )
+                 w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
+                 w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
+                 w2_scale = (
+                     layer.w2_weight_scale.unsqueeze(1)
+                     .repeat_interleave(w2_scale_n, dim=1)
+                     .unsqueeze(2)
+                     .repeat_interleave(w2_scale_k, dim=2)
+                 )
+             quant_info = DeepGemmMoeQuantInfo(
+                 w13_weight=w13_weight,
+                 w2_weight=w2_weight,
+                 use_fp8=True,
+                 w13_scale=w13_scale,
+                 w2_scale=w2_scale,
+                 block_shape=block_shape,
+             )
+         elif self.runner.runner_backend.is_triton():
+             quant_info = TritonMoeQuantInfo(
+                 w13_weight=layer.w13_weight,
+                 w2_weight=layer.w2_weight,
+                 use_fp8_w8a8=True,
+                 w13_scale=(
+                     layer.w13_weight_scale_inv
+                     if self.block_quant
+                     else layer.w13_weight_scale
+                 ),
+                 w2_scale=(
+                     layer.w2_weight_scale_inv
+                     if self.block_quant
+                     else layer.w2_weight_scale
+                 ),
+                 a13_scale=layer.w13_input_scale,
+                 a2_scale=layer.w2_input_scale,
+                 block_shape=self.quant_config.weight_block_size,
+             )
+         else:
+             raise NotImplementedError(
+                 "Unsupported runner backend: %s" % self.runner.runner_backend
+             )
+
          return self.runner.run(dispatch_output, quant_info)

      def apply_with_router_logits(
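The per-tensor to per-block conversion above is pure shape arithmetic: one scale per expert is repeated over a ceil(n/128) x ceil(k/128) block grid. A runnable sketch with made-up sizes (8 experts, 512x384 weights; w13_weight and w13_weight_scale stand in for the layer attributes):

# Sketch only: expanding a per-expert scale into per-block scales via repeat_interleave.
import torch

scale_block_size = 128
num_experts, n, k = 8, 512, 384
w13_weight = torch.empty(num_experts, n, k)   # stands in for layer.w13_weight
w13_weight_scale = torch.rand(num_experts)    # one scale per expert (per-tensor)

w13_scale_n = (n - 1) // scale_block_size + 1  # ceil(512 / 128) = 4
w13_scale_k = (k - 1) // scale_block_size + 1  # ceil(384 / 128) = 3
w13_scale = (
    w13_weight_scale.unsqueeze(1)
    .repeat_interleave(w13_scale_n, dim=1)
    .unsqueeze(2)
    .repeat_interleave(w13_scale_k, dim=2)
)
print(w13_scale.shape)  # torch.Size([8, 4, 3]) -- one scale per 128x128 block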
@@ -23,7 +23,7 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.layers.quantization import deep_gemm_wrapper
+ from sglang.srt.layers import deep_gemm_wrapper
  from sglang.srt.utils import (
      align,
      direct_register_custom_op,
@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

  if _is_cuda:
-     from sgl_kernel import (
-         sgl_per_tensor_quant_fp8,
-         sgl_per_token_group_quant_fp8,
-         sgl_per_token_quant_fp8,
-     )
+     from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+     # Temporary
+     try:
+         from sgl_kernel import sgl_per_token_group_quant_8bit
+
+         enable_sgl_per_token_group_quant_8bit = True
+     except ImportError:
+         from sgl_kernel import sgl_per_token_group_quant_fp8
+
+         enable_sgl_per_token_group_quant_8bit = False

  if _is_hip:
      if _use_aiter:
@@ -61,7 +67,7 @@ if _is_hip:
              raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
      else:
          try:
-             import vllm._C
+             import vllm._C  # noqa: F401
          except ImportError:
              raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")

@@ -477,6 +483,7 @@ def sglang_per_token_group_quant_fp8(
      scale_ue8m0: bool = False,
      fuse_silu_and_mul: bool = False,
      masked_m: Optional[torch.Tensor] = None,
+     enable_v2: Optional[bool] = None,
  ):
      assert (
          x.shape[-1] % group_size == 0
@@ -496,9 +503,26 @@
      )

      if x.shape[0] > 0:
-         sgl_per_token_group_quant_fp8(
-             x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-         )
+         # Temporary
+         if enable_sgl_per_token_group_quant_8bit:
+             sgl_per_token_group_quant_8bit(
+                 x,
+                 x_q,
+                 x_s,
+                 group_size,
+                 eps,
+                 fp8_min,
+                 fp8_max,
+                 scale_ue8m0,
+                 fuse_silu_and_mul,
+                 masked_m,
+                 enable_v2=enable_v2,
+             )
+         else:
+             assert not enable_v2
+             sgl_per_token_group_quant_fp8(
+                 x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+             )

      return x_q, x_s

@@ -514,6 +538,7 @@ def sglang_per_token_group_quant_8bit(
      scale_ue8m0: bool = False,
      fuse_silu_and_mul: bool = False,
      masked_m: Optional[torch.Tensor] = None,
+     enable_v2: Optional[bool] = None,
  ):
      from sglang.srt.layers.quantization.int8_kernel import (
          sglang_per_token_group_quant_int8,
@@ -529,6 +554,7 @@
              group_size=group_size,
              eps=eps,
              dtype=dst_dtype,
+             enable_v2=enable_v2,
          )

      return sglang_per_token_group_quant_fp8(
@@ -540,6 +566,7 @@
          scale_ue8m0=scale_ue8m0,
          fuse_silu_and_mul=fuse_silu_and_mul,
          masked_m=masked_m,
+         enable_v2=enable_v2,
      )


@@ -1804,3 +1831,21 @@ def triton_scaled_mm(
      )

      return result.to(out_dtype)
+
+
+ if _is_cuda:
+     if enable_sgl_per_token_group_quant_8bit:
+
+         @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_8bit")
+         def _(
+             input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+         ):
+             return
+
+     else:
+
+         @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8")
+         def _(
+             input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+         ):
+             return
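The register_fake calls above give torch.compile a shape-only ("fake"/meta) implementation of the sgl_kernel quantization ops, which write into preallocated buffers and therefore return nothing. A minimal self-contained sketch of the same pattern, using a hypothetical demo::scale_rows op (not an sglang or sgl_kernel API) and assuming PyTorch 2.4+ for torch.library.custom_op:

# Sketch only: a custom op plus a fake implementation that only reports shapes/dtypes.
import torch

@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Real implementation: scale each row of x by the matching entry of s.
    return x * s.unsqueeze(-1)

@torch.library.register_fake("demo::scale_rows")
def _(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Fake implementation: no computation, just the output metadata.
    return torch.empty_like(x)

out = torch.ops.demo.scale_rows(torch.ones(2, 3), torch.tensor([2.0, 3.0]))
print(out)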
@@ -2,11 +2,10 @@ from typing import Callable, List, Optional, Tuple

  import torch

- from sglang.srt import offloader
- from sglang.srt.layers.quantization import deep_gemm_wrapper
+ from sglang.srt.layers import deep_gemm_wrapper
  from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
  from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
- from sglang.srt.utils import is_sm100_supported
+ from sglang.srt.utils import ceil_div, is_sm100_supported, offloader

  try:
      from vllm import _custom_ops as ops
@@ -29,7 +28,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
  )
  from sglang.srt.utils import (
      align,
-     ceil_div,
      get_bool_env_var,
      get_cuda_version,
      get_device_capability,
@@ -443,23 +441,53 @@ def _requant_weight_ue8m0(
          torch.bfloat16,
      )

+     out_w, out_s = quant_weight_ue8m0(
+         weight_dequant=weight_dequant,
+         weight_block_size=weight_block_size,
+     )
+
+     out_s = _transform_scale_ue8m0(out_s, mn=out_w.shape[-2])
+
+     return out_w, out_s
+
+
+ def quant_weight_ue8m0(
+     weight_dequant: torch.Tensor,
+     weight_block_size: List[int],
+ ):
+     assert weight_block_size == [128, 128]
+     assert (
+         weight_dequant.dtype == torch.bfloat16
+     ), f"{weight_dequant.dtype=} {weight_dequant.shape=}"
+
+     *batch_dims, n, k = weight_dequant.shape
+
      weight_dequant_flat = weight_dequant.view((-1, k))
      out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)

-     out_w = out_w_flat.view(weight.shape)
-     out_s = out_s_flat.view(weight_scale_inv.shape)
+     out_w = out_w_flat.view((*batch_dims, n, k))
+     out_s = out_s_flat.view(
+         (
+             *batch_dims,
+             ceil_div(n, weight_block_size[0]),
+             ceil_div(k, weight_block_size[1]),
+         )
+     )
+
+     return out_w, out_s
+

-     # NOTE copy and modified from DeepGEMM
-     def _transform_scale(sf, mn: int):
-         import deep_gemm.utils.layout
+ def transform_scale_ue8m0_inplace(param, mn):
+     param.data = _transform_scale_ue8m0(param.data, mn=mn)

-         sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-         sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
-         return sf

-     out_s = _transform_scale(out_s, mn=out_w.shape[-2])
+ # NOTE copy and modified from DeepGEMM
+ def _transform_scale_ue8m0(sf, mn):
+     import deep_gemm.utils.layout

-     return out_w, out_s
+     sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+     sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
+     return sf


  # COPIED FROM DeepGEMM
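quant_weight_ue8m0 above sizes the scale tensor with ceil_div over the [128, 128] block grid while preserving any leading batch dimensions. A quick worked example with made-up shapes (ceil_div is redefined locally here; in the diff it is imported from sglang.srt.utils):

# Sketch only: the scale-grid shape produced for a [128, 128] weight block size.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

weight_block_size = [128, 128]
batch_dims, n, k = (61,), 7168, 2048  # e.g. per-expert weights stacked on dim 0 (made-up sizes)
scale_shape = (*batch_dims, ceil_div(n, weight_block_size[0]), ceil_div(k, weight_block_size[1]))
print(scale_shape)  # (61, 56, 16)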
@@ -2,7 +2,7 @@
  from __future__ import annotations

  import logging
- from typing import Any, Optional
+ from typing import Any, List, Optional

  import torch
  from torch.nn import Module
@@ -11,7 +11,6 @@ from torch.nn.parameter import Parameter
  from sglang.srt.layers.linear import LinearBase
  from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
  from sglang.srt.layers.quantization.base_config import (
-     FusedMoEMethodBase,
      LinearMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
@@ -28,7 +27,7 @@ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
      prepare_fp8_layer_for_marlin,
  )
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
- from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
+ from sglang.srt.layers.quantization.utils import is_layer_skipped
  from sglang.srt.utils import get_bool_env_var, is_cuda

  _is_cuda = is_cuda()