sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
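The bulk of the diff below is the new GGUF quantization backend (sglang/srt/layers/quantization/gguf.py, +566 lines), adapted from vLLM, followed by smaller hunks in gptq.py, int8_kernel.py, and marlin_utils.py. As a quick orientation before the full file, here is a minimal sketch of the new config's layer-skip logic; it uses only names defined in the diff, assumes this wheel is installed, and is illustrative rather than part of the release:

# Illustrative sketch (not part of the diff): modules listed in
# modules_to_not_convert keep an unquantized linear method.
from sglang.srt.layers.quantization.gguf import GGUFConfig, is_layer_skipped_gguf

cfg = GGUFConfig.from_config({"modules_to_not_convert": ["lm_head"]})
print(is_layer_skipped_gguf("lm_head", cfg.modules_to_not_convert))  # True
print(is_layer_skipped_gguf("model.layers.0.mlp.up_proj", cfg.modules_to_not_convert))  # False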
sglang/srt/layers/quantization/gguf.py
@@ -0,0 +1,566 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from: https://github.com/vllm-project/vllm/blob/ab3e80042eac24dd362408e6d63ad98768046359/vllm/model_executor/layers/quantization/gguf.py
+from __future__ import annotations
+
+import logging
+import warnings
+from typing import TYPE_CHECKING, Any, List, Optional
+
+import gguf
+import torch
+from gguf import GGMLQuantizationType as WeightType
+from torch.nn.parameter import Parameter, UninitializedParameter
+
+from sglang.srt.layers.linear import LinearBase
+from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.utils import is_cuda, is_hip, is_xpu, set_weight_attrs
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_xpu = is_xpu()
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, moe_align_block_size, moe_sum, silu_and_mul
+    from sgl_kernel.quantization import (
+        ggml_dequantize,
+        ggml_moe_a8,
+        ggml_moe_a8_vec,
+        ggml_moe_get_block_size,
+        ggml_mul_mat_a8,
+        ggml_mul_mat_vec_a8,
+    )
+else:
+    warnings.warn("Only CUDA supports GGUF quantization currently.")
+
+logger = logging.getLogger(__name__)
+
+
+class GGUFConfig(QuantizationConfig):
+    """Config class for GGUF."""
+
+    def __init__(self, modules_to_not_convert: list[str] | None = None) -> None:
+        super().__init__()
+        self.modules_to_not_convert = modules_to_not_convert or []
+
+    def __repr__(self) -> str:
+        return "GGUFConfig()"
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def get_name(self) -> "str":
+        return "gguf"
+
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
+        return [torch.half, torch.bfloat16, torch.float32]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []  # no extra configs.
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None
+        )
+        return cls(modules_to_not_convert)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_gguf(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return GGUFLinearMethod(self)
+        elif isinstance(layer, VocabParallelEmbedding):
+            return GGUFEmbeddingMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return GGUFMoEMethod(self)
+        return None
+
+
+def is_layer_skipped_gguf(prefix: str, modules_to_not_convert: list[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+STANDARD_QUANT_TYPES = {
+    WeightType.Q4_0,
+    WeightType.Q4_1,
+    WeightType.Q5_0,
+    WeightType.Q5_1,
+    WeightType.Q8_0,
+    WeightType.Q8_1,
+}
+KQUANT_TYPES = {
+    WeightType.Q2_K,
+    WeightType.Q3_K,
+    WeightType.Q4_K,
+    WeightType.Q5_K,
+    WeightType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+    WeightType.IQ1_M,
+    WeightType.IQ1_S,
+    WeightType.IQ2_XXS,
+    WeightType.IQ2_XS,
+    WeightType.IQ2_S,
+    WeightType.IQ3_XXS,
+    WeightType.IQ3_S,
+    WeightType.IQ4_XS,
+    WeightType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+def fused_mul_mat_gguf(
+    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
+) -> torch.Tensor:
+    if qweight_type in IMATRIX_QUANT_TYPES:
+        mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
+    else:
+        mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
+    # HACK: when doing chunked prefill we don't generate output tokens
+    # so input to logits generator is empty which causes invalid parameter
+    if x.shape[0] == 0:
+        return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
+    # there is no need to call any kernel for fp16/bf16
+    if qweight_type in UNQUANTIZED_TYPES:
+        return x @ qweight.T
+    # enable MMVQ in contiguous batching with batch_size=1
+    if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
+        y = ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # Use MMQ Kernel if it's available (standard + k-quants)
+    elif qweight_type in MMQ_QUANT_TYPES:
+        y = ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+    # If there is no available MMQ kernel, fallback to dequantize
+    elif qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+        weight = ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        y = x @ weight.T
+    else:
+        # Raise an error if the quantization type is not supported.
+        # Might be useful if llama.cpp adds a new quantization type.
+        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+    return y
+
+
+def fused_moe_gguf(
+    x: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    qweight_type: int,
+    qweight_type2: int,
+    activation: str,
+) -> torch.Tensor:
+    def act(x: torch.Tensor):
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        if activation == "silu":
+            silu_and_mul(out, x)
+        elif activation == "gelu":
+            gelu_and_mul(out, x)
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+        return out
+
+    out_hidden_states = torch.empty_like(x)
+    # unless we have decent expert reuse we are better off running the moe_vec kernel
+    if (
+        qweight_type2 in MMQ_QUANT_TYPES
+        and qweight_type in MMQ_QUANT_TYPES
+        and x.shape[0] > 64
+    ):
+        num_tokens, _ = x.shape
+        E, N, _ = w1.shape
+        top_k = topk_ids.shape[1]
+        BLOCK_SIZE = ggml_moe_get_block_size(qweight_type)
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+            topk_ids, BLOCK_SIZE, E
+        )
+        out = ggml_moe_a8(
+            x,
+            w1,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            qweight_type,
+            N,
+            top_k,
+            num_tokens,
+        )
+        out = act(out)
+        out = ggml_moe_a8(
+            out,
+            w2,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            qweight_type2,
+            w2.shape[1],
+            1,
+            num_tokens * top_k,
+        )
+        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
+            topk_weights.view(num_tokens, top_k, 1)
+        )
+        # TODO(FlamingoPg): maybe we can use moe_sum_reduce here?
+        moe_sum(out, out_hidden_states)
+    elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
+        num_tokens, _ = x.shape
+        E, N, _ = w1.shape
+        top_k = topk_ids.shape[1]
+
+        out = ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens)
+        out = act(out)
+
+        out = ggml_moe_a8_vec(
+            out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
+        )
+        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
+            topk_weights.view(num_tokens, top_k, 1)
+        )
+        moe_sum(out, out_hidden_states)
+    else:
+        logger.warning_once(
+            "There is no support for fast MoE kernel "
+            "for current quantization method. "
+            "Falling back to slow implementation. "
+        )
+        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
+            inp = x[tok].reshape((1,) + x.shape[1:])
+            current_hidden_state = None
+            for ww, ii in zip(w, idx):
+                expert_up = w1[ii]
+
+                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
+                out = act(out)
+
+                expert_down = w2[ii]
+                current_state = fused_mul_mat_gguf(
+                    out, expert_down, qweight_type2
+                ).mul_(ww)
+                if current_hidden_state is None:
+                    current_hidden_state = current_state
+                else:
+                    current_hidden_state.add_(current_state)
+            out_hidden_states[tok] = current_hidden_state
+    return out_hidden_states
+
+
+def apply_gguf_embedding(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    hidden_size: int,
+    dtype: torch.dtype | None = None,
+) -> torch.Tensor:
+    if qweight_type in UNQUANTIZED_TYPES:
+        return torch.embedding(qweight, x)
+    elif qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        x_flat = x.flatten()
+        assert hidden_size == qweight.shape[1] // type_size * block_size
+        quant = torch.index_select(qweight, dim=0, index=x_flat)
+        dequant = ggml_dequantize(
+            quant, qweight_type, hidden_size, x_flat.shape[0], dtype
+        )
+        return dequant.view(*x.shape, hidden_size)
+    else:
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+
+
+class GGUFLinearMethod(LinearMethodBase):
+    """Linear method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def __init__(self, quant_config: GGUFConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        self.params_dtype = params_dtype
+        output_size_per_partition = sum(output_partition_sizes)
+
+        tensor_shape = (output_size_per_partition, input_size_per_partition)
+        qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            qweight,
+            {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+                "shard_id": [],
+                "shard_id_map": {},
+            },
+        )
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("qweight", qweight)
+
+        qweight_type = Parameter(
+            torch.empty(len(output_partition_sizes), dtype=torch.uint8),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight_type,
+            {
+                "is_gguf_weight_type": True,
+                "weight_type": 0,
+                "shard_weight_type": {},
+                "ignore_warning": True,
+            },
+        )
+        set_weight_attrs(qweight_type, extra_weight_attrs)
+        layer.register_parameter("qweight_type", qweight_type)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        qweight_type = layer.qweight_type.weight_type
+        if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES):
+            qweight_type = WeightType(qweight_type)
+            raise ValueError(
+                f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
+            )
+        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
+        # materialize the padded weight parameter for CUDA Graph compatibility.
+        self._create_padded_weight_param(layer)
+
+    def _create_padded_weight_param(self, layer: torch.nn.Module):
+        """Create padded weight parameter for GGUF MergedLinear layer."""
+        qweight = layer.qweight
+        shard_id_map = qweight.shard_id_map
+        shard_id = qweight.shard_id
+        if len(data_container := qweight.data_container) > 1:
+            dtype = {data.dtype for data in data_container}
+            assert len(dtype) == 1, ValueError(
+                f"Data container has mixed dtypes: {dtype}"
+            )
+            dtype = next(iter(dtype))
+            # concat dim0 and pad dim1
+            padded_side = max(x.size(1) for x in data_container)
+            concat_side = sum(x.size(0) for x in data_container)
+            # Pad the quantized weights to dense tensor, and create a map
+            # with the location of each shard in the padded tensor.
+            padded_data = torch.zeros(
+                (concat_side, padded_side), dtype=dtype, device=qweight.device
+            )
+            # (dim0_start, dim0_end, dim1_size)
+            shard_offset_map = dict[str, tuple[int, int, int]]()
+            for idx in shard_id:
+                id_in_container = shard_id_map[idx]
+                start = sum(x.size(0) for x in data_container[:id_in_container])
+                end = start + data_container[id_in_container].size(0)
+                size = data_container[id_in_container].size(1)
+                padded_data[start:end, :size] = data_container[id_in_container]
+                shard_offset_map[idx] = (start, end, size)
+            qweight.data_container.clear()
+            padded_param = Parameter(padded_data, requires_grad=False)
+            set_weight_attrs(padded_param, vars(qweight))
+            set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
+            layer.register_parameter("qweight", padded_param)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        shard_id = layer.qweight.shard_id
+
+        if shard_id:
+            # dequantize shard weights respectively
+            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
+            qweight = layer.qweight
+            result = []
+            for idx in shard_id:
+                start, end, offset = layer.qweight.shard_offset_map[idx]
+                qweight_type = layer.qweight_type.shard_weight_type[idx]
+                result.append(
+                    fused_mul_mat_gguf(
+                        x, qweight[start:end, :offset].contiguous(), qweight_type
+                    )
+                )
+            out = torch.cat(result, axis=1)
+        else:
+            qweight = layer.qweight
+            qweight_type = layer.qweight_type.weight_type
+            out = fused_mul_mat_gguf(x, qweight, qweight_type)
+        if bias is not None:
+            out.add_(bias)
+        return out
+
+
+class GGUFMoEMethod(FusedMoEMethodBase):
+    """MoE method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def __init__(self, quant_config: GGUFConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # gate up proj
+        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w13_qweight,
+            {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            },
+        )
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        layer.register_parameter("w13_qweight", w13_qweight)
+
+        w13_qweight_type = Parameter(
+            torch.empty(1, dtype=torch.uint8), requires_grad=False
+        )
+        set_weight_attrs(
+            w13_qweight_type,
+            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
+        )
+        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w13_qweight_type", w13_qweight_type)
+
+        tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size)
+        # gate down proj
+        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w2_qweight,
+            {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            },
+        )
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        layer.register_parameter("w2_qweight", w2_qweight)
+
+        w2_qweight_type = Parameter(
+            torch.empty(1, dtype=torch.uint8), requires_grad=False
+        )
+        set_weight_attrs(
+            w2_qweight_type,
+            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
+        )
+
+        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w2_qweight_type", w2_qweight_type)
+
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        assert self.fused_experts is None
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
+
+        topk_weights, topk_ids, _ = topk_output
+        output = fused_moe_gguf(
+            x=x,
+            w1=layer.w13_qweight,
+            w2=layer.w2_qweight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            qweight_type=layer.w13_qweight_type.weight_type,
+            qweight_type2=layer.w2_qweight_type.weight_type,
+            activation=moe_runner_config.activation,
+        )
+        return StandardCombineInput(hidden_states=output)


+class GGUFEmbeddingMethod(GGUFLinearMethod):
+    """Embedding method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
+        qweight = layer.qweight
+        qweight_type = layer.qweight_type.weight_type
+        hidden_size = qweight.tensor_shape[1]
+
+        return apply_gguf_embedding(
+            x, qweight, qweight_type, hidden_size, dtype=self.params_dtype
+        )
+
+
+class GGUFUninitializedParameter(UninitializedParameter):
+    cls_to_become = Parameter
+    data_container: list[torch.Tensor]
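A note on the dequantize fallback in fused_mul_mat_gguf above: GGML packs block_size logical values into type_size bytes per block, so a packed (rows, n_bytes) tensor corresponds to a logical (rows, n_bytes // type_size * block_size) matrix. A standalone sketch using only the public gguf package (the tensor dimensions are made up for illustration):

import gguf
from gguf import GGMLQuantizationType as WeightType

# Each Q4_K block stores 256 logical values in 144 bytes.
block_size, type_size = gguf.GGML_QUANT_SIZES[WeightType.Q4_K]

rows, packed_bytes = 4096, 2304  # hypothetical packed weight; bytes divisible by type_size
logical_shape = (rows, packed_bytes // type_size * block_size)
print(logical_shape)  # (4096, 4096), the shape handed to ggml_dequantize above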
sglang/srt/layers/quantization/gptq.py
@@ -199,7 +199,6 @@ class GPTQConfig(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[LinearMethodBase]:
         # Delay the import to avoid circular dependency
-        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, FusedMoE):
sglang/srt/layers/quantization/int8_kernel.py
@@ -12,7 +12,15 @@ from sglang.srt.utils import get_device_name, is_cuda
 
 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_int8
+
+        enable_sgl_per_token_group_quant_8bit = False
 
 logger = logging.getLogger(__name__)
 
@@ -187,6 +195,7 @@ def sglang_per_token_group_quant_int8(
     group_size: int,
     eps: float = 1e-10,
     dtype: torch.dtype = torch.int8,
+    enable_v2: Optional[bool] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -204,7 +213,14 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )
 
-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    # Temporary
+    if enable_sgl_per_token_group_quant_8bit:
+        sgl_per_token_group_quant_8bit(
+            x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
+        )
+    else:
+        assert not enable_v2
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
 
     return x_q, x_s
 
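The int8_kernel.py hunks above implement a temporary import-time capability probe: if the unified sgl_per_token_group_quant_8bit kernel exists in the installed sgl_kernel, it is used (and may receive enable_v2); otherwise the code falls back to the older int8-only kernel and asserts that enable_v2 was not requested. A self-contained sketch of the same pattern, with numpy standing in for sgl_kernel so the example runs anywhere:

# Probe once at import time, then branch on a module-level flag in the hot path.
try:
    from numpy import dot as _fast_dot  # stand-in for the preferred kernel

    _HAS_FAST_PATH = True
except ImportError:
    _fast_dot = None  # stand-in for the legacy fallback
    _HAS_FAST_PATH = False


def dot(a, b):
    if _HAS_FAST_PATH:
        return float(_fast_dot(a, b))
    return sum(x * y for x, y in zip(a, b))


print(dot([1.0, 2.0], [3.0, 4.0]))  # 11.0 via either path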
sglang/srt/layers/quantization/marlin_utils.py
@@ -4,6 +4,7 @@
 from __future__ import annotations
 
 import logging
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
 import numpy
@@ -57,6 +58,17 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 USE_FP32_REDUCE_DEFAULT = True
 
 
+@dataclass
+class MarlinLinearLayerConfig:
+    full_weight_shape: tuple[int, int]  # [in, out]
+    partition_weight_shape: tuple[int, int]
+    weight_type: ScalarType
+    act_type: torch.dtype
+    group_size: int
+    zero_points: bool
+    has_g_idx: bool
+
+
 # For binary size and compile time, we don't support the same types for with and
 # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
 # TODO: we may want to move this into the C++ so its closer to the actual impl
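Finally, marlin_utils.py gains a MarlinLinearLayerConfig dataclass that bundles the shape and type parameters the Marlin kernels consume. A hedged construction sketch: the scalar_types import path and the uint4b8 member are assumptions based on sglang's vendored ScalarType helpers (verify against marlin_utils.py's own imports), and the shapes are illustrative:

import torch

from sglang.srt.layers.quantization.marlin_utils import MarlinLinearLayerConfig

# Assumed import path and member name; check before relying on them.
from sglang.srt.layers.quantization.scalar_type import scalar_types

cfg = MarlinLinearLayerConfig(
    full_weight_shape=(4096, 11008),      # [in, out] of the unsharded weight
    partition_weight_shape=(4096, 5504),  # this TP rank's shard
    weight_type=scalar_types.uint4b8,     # GPTQ-style 4-bit with a bias of 8
    act_type=torch.half,
    group_size=128,
    zero_points=False,
    has_g_idx=False,
)
print(cfg)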